
Commit 5f497ce

[Main Update] Doc to messages feature support and Split simple and chat mode (#692)
* Update deps
* Restructured
* Delete models
* Remove deprecated models
* Set up auto doc to messages and chat models
* Lint
* Allow force simple mode
* Add auto doc to messages for audio and video
1 parent 477b802 commit 5f497ce


90 files changed: +286 / −3670 lines

lmms_eval/__main__.py

Lines changed: 2 additions & 17 deletions
@@ -265,6 +265,7 @@ def parse_eval_args() -> argparse.Namespace:
         help="Sets trust_remote_code to True to execute code to create HF Datasets from the Hub",
     )
     parser.add_argument("--process_with_media", action="store_true", help="Whether you will process you dataset with audio, image. By default set to False" "In case some benchmarks need to be processed with media, set this flag to True.")
+    parser.add_argument("--force_simple", action="store_true", help="Force the evaluation to use the simple mode of the models")
     args = parser.parse_args()
     return args

@@ -421,23 +422,6 @@ def cli_evaluate_single(args: Union[argparse.Namespace, None] = None) -> None:
     elif args.tasks == "list_subtasks":
         eval_logger.info(task_manager.list_all_tasks(list_groups=False, list_tags=False))
         sys.exit()
-    elif args.tasks == "list_with_num":
-        log_message = (
-            "\n" + "=" * 70 + "\n" + "\n\tYou are trying to check all the numbers in each task." + "\n\tThis action will download the complete dataset." + "\n\tIf the results are not clear initially, call this again." + "\n\n" + "=" * 70
-        )
-        eval_logger.info(log_message)
-        for task_name in sorted(task_manager.list_all_tasks()):
-            try:
-                task_dict = get_task_dict([task_name], model_name="llava")
-                task_obj = task_dict[task_name]
-                if type(task_obj) == tuple:
-                    group, task_obj = task_obj
-                if task_obj is None:
-                    continue
-                eval_logger.info(f"\nTask : {task_obj.config.task}\n - #num : {len(task_obj.test_docs()) if task_obj.has_test_docs() else len(task_obj.validation_docs())}")
-            except Exception as e:
-                eval_logger.debug(f"\nTask : {task_name} fail to load \n Exception : \n {e}")
-        sys.exit()
     else:
         if os.path.isdir(args.tasks):
             import glob

@@ -496,6 +480,7 @@ def cli_evaluate_single(args: Union[argparse.Namespace, None] = None) -> None:
             fewshot_random_seed=args.seed[3],
             cli_args=args,
             datetime_str=datetime_str,
+            force_simple=args.force_simple,
             **request_caching_args,
         )
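
Read together, these hunks add the flag to the CLI parser and thread it through to the evaluator call. A minimal standalone sketch of that flow (only the argparse behavior is shown; the simple_evaluate call is elided to a comment):

# Sketch: how --force_simple travels from the command line into the evaluator.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--force_simple", action="store_true", help="Force the evaluation to use the simple mode of the models")
args = parser.parse_args(["--force_simple"])  # simulate passing the new flag
assert args.force_simple is True

# cli_evaluate_single() then forwards it unchanged:
# results = evaluator.simple_evaluate(..., force_simple=args.force_simple)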

lmms_eval/api/model.py

Lines changed: 2 additions & 0 deletions
@@ -15,6 +15,8 @@


 class lmms(abc.ABC):
+    is_simple: bool = True
+
     def __init__(self) -> None:
         """Defines the interface that should be implemented by all lmms subclasses.
         lmmss are assumed to take image-text as input and yield strings as output
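
With the class attribute defaulting to True, every existing model stays in simple mode unless it opts out. A minimal sketch of how a chat-mode implementation would override it (DummyChatModel is hypothetical; assumes lmms_eval is installed):

# Hypothetical subclass illustrating the is_simple switch; not part of this commit.
from lmms_eval.api.model import lmms

class DummyChatModel(lmms):
    is_simple = False  # route this model through chat-template (messages-based) tasks

    # the abstract lmms interface (e.g. generate_until) would be implemented here

# Downstream code can then branch on the attribute without isinstance checks:
# task_type = "simple" if lm.is_simple else "chat"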

lmms_eval/api/task.py

Lines changed: 45 additions & 1 deletion
@@ -32,6 +32,7 @@
 from datasets import Audio, DownloadConfig, Image, Sequence
 from huggingface_hub import snapshot_download
 from loguru import logger as eval_logger
+from PIL import Image as PIL_Image
 from PIL import ImageFile
 from tenacity import retry, stop_after_attempt, stop_after_delay, wait_fixed
 from tqdm import tqdm

@@ -91,6 +92,7 @@ class TaskConfig(dict):
     doc_to_text: Union[Callable, str] = None
     doc_to_target: Union[Callable, str] = None
     doc_to_choice: Union[Callable, str, dict, list] = None
+    doc_to_messages: Callable = None
     process_results: Union[Callable, str] = None
     use_prompt: str = None
     description: str = ""

@@ -952,7 +954,7 @@ def _download_from_youtube(path):
         )
         zip_files = glob(os.path.join(cache_path, "**/*.zip"), recursive=True)
         tar_files = glob(os.path.join(cache_path, "**/*.tar*"), recursive=True)
-
+
         def unzip_video_data(zip_file):
             import os
             import zipfile

@@ -1634,3 +1636,45 @@ def task_name(self) -> Any:

     def __repr__(self):
         return f"ConfigurableTask(task_name={getattr(self.config, 'task', None)}," f"output_type={self.OUTPUT_TYPE}," f"num_fewshot={getattr(self.config, 'num_fewshot', None)}," f"num_samples={len(self.eval_docs)})"
+
+
+class ConfigurableMessagesTask(ConfigurableTask):
+    def doc_to_messages(self, doc: dict) -> Union[int, str, list]:
+        if callable(self.config.doc_to_messages):
+            return (
+                self.config.doc_to_messages(doc, self.lmms_eval_specific_kwargs)
+                if self.lmms_eval_specific_kwargs is not None and len(inspect.signature(self.config.doc_to_messages).parameters) == 2
+                else self.config.doc_to_messages(
+                    doc,
+                )
+            )
+        elif self.config.doc_to_messages is None and self.config.doc_to_visual is not None and self.config.doc_to_text is not None:
+            # An auto doc to messages function
+            def auto_doc_to_messages(doc):
+                visuals = self.doc_to_visual(doc)
+                text = self.doc_to_text(doc)
+                messages = [{"role": "user", "content": []}]
+                content = []
+                for visual in visuals:
+                    if isinstance(visual, PIL_Image.Image):
+                        content.append({"type": "image", "url": visual})
+                    elif isinstance(visual, dict):
+                        content.append({"type": "audio", "url": visual})
+                    elif isinstance(visual, str):
+                        content.append({"type": "video", "url": visual})
+                content.append({"type": "text", "text": text})
+                messages[0]["content"] = content
+                return messages
+
+            return auto_doc_to_messages(doc)
+        else:
+            # eval_logger.warning("Note that doc_to_visual was called but not set in config. Please check if this is a text-only task.")
+            return self.config.doc_to_messages
+
+    def construct_requests(self, doc_id: int, ctx: str, **kwargs) -> Union[List[Instance], Instance]:
+        split = kwargs.get("metadata").get("split")
+        # kwargs.pop("split")
+        assert self.OUTPUT_TYPE == "generate_until", "Currently messages is used for generation only"
+
+        arguments = (self.doc_to_messages, copy.deepcopy(self.config.generation_kwargs), doc_id, self.config.task, split)
+        return Instance(request_type=self.OUTPUT_TYPE, arguments=arguments, idx=0, **kwargs)
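
When a task defines doc_to_visual and doc_to_text but no doc_to_messages, the fallback above synthesizes a single user turn: one content part per visual (PIL image → image, dict → audio, string path/URL → video) followed by the question text. A standalone sketch of the same mapping, runnable without lmms_eval:

# Mirrors auto_doc_to_messages above: classify each visual by Python type.
from PIL import Image as PIL_Image

def sketch_auto_doc_to_messages(visuals, text):
    content = []
    for visual in visuals:
        if isinstance(visual, PIL_Image.Image):
            content.append({"type": "image", "url": visual})
        elif isinstance(visual, dict):
            content.append({"type": "audio", "url": visual})  # e.g. a decoded audio dict
        elif isinstance(visual, str):
            content.append({"type": "video", "url": visual})  # a file path or URL
    content.append({"type": "text", "text": text})  # the question text always goes last
    return [{"role": "user", "content": content}]

demo = sketch_auto_doc_to_messages([PIL_Image.new("RGB", (64, 64))], "What is shown?")
assert demo[0]["content"][-1] == {"type": "text", "text": "What is shown?"}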

lmms_eval/evaluator.py

Lines changed: 5 additions & 5 deletions
@@ -79,6 +79,7 @@ def simple_evaluate(
     datetime_str: str = get_datetime_str(),
     distributed_executor_backend: str = "accelerate",
     cli_args=None,
+    force_simple: bool = False,
 ):
     """Instantiate and evaluate a model on a list of tasks.

@@ -172,12 +173,10 @@ def simple_evaluate(
     if task_manager is None:
         task_manager = TaskManager(verbosity, model_name=model)

-    task_dict = get_task_dict(tasks, task_manager)
-
     if isinstance(model, str):
         if model_args is None:
             model_args = ""
-        lm = lmms_eval.models.get_model(model).create_from_arg_string(
+        lm = lmms_eval.models.get_model(model, force_simple).create_from_arg_string(
             model_args,
             {
                 "batch_size": batch_size,

@@ -187,6 +186,8 @@ def simple_evaluate(
         )
     elif isinstance(model, lmms_eval.api.model.lmms):
         lm = model
+    task_type = "simple" if lm.is_simple else "chat"
+    task_dict = get_task_dict(tasks, task_manager, task_type)

     # helper function to recursively apply config overrides to leaf subtasks, skipping their constituent groups.
     # (setting of num_fewshot ; bypassing metric calculation ; setting fewshot seed)

@@ -551,8 +552,7 @@ def evaluate(
                     ensure_ascii=False,
                 )
             ),
-            "prompt_hash": hash_string(requests[0].arguments[0]),
-            "target_hash": hash_string(str(target)),
+            # Removing prompt hash and target hash here
         }
         example.update(metrics)
         task_output.logged_samples.append(example)
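
Note that get_task_dict now runs after the model is instantiated, because the task flavor depends on the model: lm.is_simple selects between the classic task pipeline and the new messages-based one. A minimal sketch of that dispatch (the classes are stand-ins; get_task_dict's task_type argument is assumed to pick the task class):

# Stand-ins for lmms instances; only the is_simple attribute matters here.
class _SimpleModel:
    is_simple = True

class _ChatModel:
    is_simple = False

def pick_task_type(lm) -> str:
    # Same expression as in the diff above.
    return "simple" if lm.is_simple else "chat"

assert pick_task_type(_SimpleModel()) == "simple"
assert pick_task_type(_ChatModel()) == "chat"
# task_dict = get_task_dict(tasks, task_manager, pick_task_type(lm))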

lmms_eval/models/__init__.py

Lines changed: 22 additions & 5 deletions
@@ -1,6 +1,7 @@
 import importlib
 import os
 import sys
+from typing import Literal

 import hf_transfer
 from loguru import logger

@@ -12,7 +13,8 @@
 log_format = "<green>{time:YYYY-MM-DD HH:mm:ss}</green> | " "<level>{level: <8}</level> | " "<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - " "<level>{message}</level>"
 logger.add(sys.stdout, level="WARNING", format=log_format)

-AVAILABLE_MODELS = {
+
+AVAILABLE_SIMPLE_MODELS = {
     "aero": "Aero",
     "plm": "PerceptionLM",
     "aria": "Aria",

@@ -77,14 +79,28 @@
     "vora": "VoRA",
 }

+AVAILABLE_CHAT_TEMPLATE_MODELS = {"llava_hf": "LlavaHf"}
+

-def get_model(model_name):
-    if model_name not in AVAILABLE_MODELS:
+def get_model(model_name, force_simple: bool = False):
+    if model_name not in AVAILABLE_SIMPLE_MODELS and model_name not in AVAILABLE_CHAT_TEMPLATE_MODELS:
         raise ValueError(f"Model {model_name} not found in available models.")

+    if model_name in AVAILABLE_CHAT_TEMPLATE_MODELS:
+        model_type = "chat"
+        AVAILABLE_MODELS = AVAILABLE_CHAT_TEMPLATE_MODELS
+    else:
+        model_type = "simple"
+        AVAILABLE_MODELS = AVAILABLE_SIMPLE_MODELS
+
+    # Override with force_simple if needed
+    if force_simple:
+        model_type = "simple"
+        AVAILABLE_MODELS = AVAILABLE_SIMPLE_MODELS
+
     model_class = AVAILABLE_MODELS[model_name]
     if "." not in model_class:
-        model_class = f"lmms_eval.models.{model_name}.{model_class}"
+        model_class = f"lmms_eval.models.{model_type}.{model_name}.{model_class}"

     try:
         model_module, model_class = model_class.rsplit(".", 1)

@@ -99,5 +115,6 @@ def get_model(model_name):
 # Allow specifying other packages to import models from
 for plugin in os.environ["LMMS_EVAL_PLUGINS"].split(","):
     m = importlib.import_module(f"{plugin}.models")
+    # For plugin users, this will be replaced by chat template model later
     for model_name, model_class in getattr(m, "AVAILABLE_MODELS").items():
-        AVAILABLE_MODELS[model_name] = f"{plugin}.models.{model_name}.{model_class}"
+        AVAILABLE_SIMPLE_MODELS[model_name] = f"{plugin}.models.{model_name}.{model_class}"
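
The new lookup checks both registries, prefers the chat registry when the model has a chat implementation, and lets force_simple override that preference. A standalone sketch of the same control flow (resolve_model_path is a hypothetical helper with toy registry tables, so it runs without the real model modules):

AVAILABLE_SIMPLE_MODELS = {"llava_hf": "LlavaHf"}
AVAILABLE_CHAT_TEMPLATE_MODELS = {"llava_hf": "LlavaHf"}

def resolve_model_path(model_name: str, force_simple: bool = False) -> str:
    if model_name not in AVAILABLE_SIMPLE_MODELS and model_name not in AVAILABLE_CHAT_TEMPLATE_MODELS:
        raise ValueError(f"Model {model_name} not found in available models.")
    # Chat registry wins by default; force_simple flips it back to simple mode.
    if model_name in AVAILABLE_CHAT_TEMPLATE_MODELS and not force_simple:
        model_type, table = "chat", AVAILABLE_CHAT_TEMPLATE_MODELS
    else:
        model_type, table = "simple", AVAILABLE_SIMPLE_MODELS
    return f"lmms_eval.models.{model_type}.{model_name}.{table[model_name]}"

print(resolve_model_path("llava_hf"))                     # lmms_eval.models.chat.llava_hf.LlavaHf
print(resolve_model_path("llava_hf", force_simple=True))  # lmms_eval.models.simple.llava_hf.LlavaHf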

lmms_eval/models/chat/llava_hf.py

Lines changed: 122 additions & 0 deletions
@@ -0,0 +1,122 @@
+import warnings
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import PIL
+import torch
+from accelerate import Accelerator, DistributedType
+from accelerate.state import AcceleratorState
+from decord import VideoReader, cpu
+from tqdm import tqdm
+from transformers import (
+    AutoConfig,
+    AutoProcessor,
+    LlavaForConditionalGeneration,
+    LlavaNextForConditionalGeneration,
+)
+
+from lmms_eval import utils
+from lmms_eval.api.instance import Instance
+from lmms_eval.api.model import lmms
+from lmms_eval.api.registry import register_model
+from lmms_eval.protocol import ChatMessages
+
+warnings.filterwarnings("ignore")
+
+from loguru import logger as eval_logger
+
+from lmms_eval.api.registry import register_model
+from lmms_eval.models.simple.llava_hf import LlavaHf as LlavaHfSimple
+
+DEFAULT_IMAGE_TOKEN = "<image>"
+DEFAULT_VIDEO_TOKEN = "<video>"
+
+# Default chat for llava-hf/llava-1.5 models: https://huggingface.co/collections/llava-hf/llava-15-65f762d5b6941db5c2ba07e0
+VICUNA_CHAT_TEMPLATE = "{% for message in messages %}{% if loop.index0 == 0 %}A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: {{ message['content'] }} {% elif message['role'] == 'user' %}USER: {{ message['content'] }} {% else %} ASSISTANT: {{ message['content'] }}{{ eos_token }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'ASSISTANT:' }}{% endif %}"
+
+
+@register_model("llava_hf_chat")
+class LlavaHf(LlavaHfSimple):
+    is_simple = False
+
+    def generate_until(self, requests: List[Instance]) -> List[str]:
+        res = []
+
+        # A dummy collate here to sort by doc id
+        def _collate(x):
+            return x[2], x[2]
+
+        # we group requests by their generation_kwargs,
+        # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
+        # in the same batch.
+        re_ords = utils.Collator([reg.args for reg in requests], _collate, grouping=True)
+        chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None)
+        num_iters = len(requests) // self.batch_size if len(requests) % self.batch_size == 0 else len(requests) // self.batch_size + 1
+        pbar = tqdm(total=num_iters, disable=(self.rank != 0), desc="Model Responding")
+        for chunk in chunks:
+            doc_to_messages, all_gen_kwargs, doc_id, task, split = zip(*chunk)
+            task = task[0]
+            split = split[0]
+            chat_messages = [doc_to_messages[0](self.task_dict[task][split][ids]) for ids in doc_id]
+            chat_messages: List[ChatMessages] = [ChatMessages(**{"messages": message}) for message in chat_messages]
+            visuals = []
+            videos = []
+            for messages in chat_messages:
+                visual, video, _ = messages.extract_media()
+                visuals.append(visual)
+                videos.append(video)
+            visuals = self.flatten(visuals)
+            videos = self.flatten(videos)
+            assert self.batch_size_per_gpu == 1, "Do not support batch_size_per_gpu > 1 for now"
+
+            # Apply chat template
+            messages = chat_messages[0].model_dump()["messages"]
+            text = self._image_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+            if self.accelerator.is_main_process and doc_id[0] % 100 == 0:
+                eval_logger.debug(f"Prompt for doc ID {doc_id[0]}:\n\n{text}\n")
+
+            if len(videos) == 0:
+                videos = None
+            inputs = self._image_processor(images=visuals, videos=videos, text=text, return_tensors="pt").to(self._device, self.model.dtype)
+
+            # we assume all gen kwargs in the batch are the same
+            # this is safe to assume because the `grouper` object ensures it.
+            gen_kwargs = all_gen_kwargs[0]
+
+            gen_kwargs["image_sizes"] = [visuals[idx].size for idx in range(len(visuals))]
+            if "max_new_tokens" not in gen_kwargs:
+                gen_kwargs["max_new_tokens"] = 1024
+            if "temperature" not in gen_kwargs:
+                gen_kwargs["temperature"] = 0
+            if "top_p" not in gen_kwargs:
+                gen_kwargs["top_p"] = None
+            if "num_beams" not in gen_kwargs:
+                gen_kwargs["num_beams"] = 1
+            try:
+                cont = self.model.generate(
+                    **inputs,
+                    do_sample=True if gen_kwargs["temperature"] > 0 else False,
+                    temperature=gen_kwargs["temperature"],
+                    top_p=gen_kwargs["top_p"],
+                    num_beams=gen_kwargs["num_beams"],
+                    max_new_tokens=gen_kwargs["max_new_tokens"],
+                    use_cache=self.use_cache,
+                    pad_token_id=self.eot_token_id,
+                    eos_token_id=self.eot_token_id,
+                )
+                cont = cont[:, inputs["input_ids"].shape[-1] :]
+            except Exception as e:
+                eval_logger.error(f"Error {e} in generating")
+                cont = ""
+            text_outputs = self.tokenizer.batch_decode(cont, skip_special_tokens=True)[0]
+            if self.accelerator.is_main_process and doc_id[0] % 100 == 0:
+                eval_logger.debug(f"Generated text for doc ID {doc_id[0]}:\n\n{text_outputs}\n")
+
+            res.append(text_outputs)
+            self.cache_hook.add_partial("generate_until", (text, gen_kwargs), text_outputs)
+            pbar.update(1)
+        # reorder this group of results back to original unsorted form
+        res = re_ords.get_original(res)
+
+        pbar.close()
+        return res
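
In chat mode the request arguments carry a doc_to_messages callable rather than a pre-rendered prompt string; generate_until unpacks (doc_to_messages, gen_kwargs, doc_id, task, split), builds the message list per doc, and validates it through ChatMessages before applying the processor's chat template. A sketch of that payload shape (the doc and task names are made up; the ChatMessages validation step is elided since it needs the installed package):

from PIL import Image as PIL_Image

def doc_to_messages(doc):
    # Same structure the auto fallback in lmms_eval/api/task.py produces.
    return [{"role": "user", "content": [
        {"type": "image", "url": doc["image"]},
        {"type": "text", "text": doc["question"]},
    ]}]

doc = {"image": PIL_Image.new("RGB", (32, 32)), "question": "Describe the image."}
arguments = (doc_to_messages, {"max_new_tokens": 1024}, 0, "demo_task", "test")

# generate_until unpacks one chunk of such tuples:
fn, gen_kwargs, doc_id, task, split = arguments
messages = fn(doc)
# ...then wraps them as ChatMessages(messages=messages), extracts the media via
# extract_media(), and feeds the templated text plus pixels to the processor.
assert len(messages[0]["content"]) == 2  # one image part, one text part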
