
Commit 3883882

Set up auto doc to messages and chat models

1 parent 26bbaa1 · commit 3883882

File tree: 8 files changed, +244 −35 lines

lmms_eval/__main__.py

Lines changed: 0 additions & 17 deletions
@@ -420,23 +420,6 @@ def cli_evaluate_single(args: Union[argparse.Namespace, None] = None) -> None:
     elif args.tasks == "list_subtasks":
         eval_logger.info(task_manager.list_all_tasks(list_groups=False, list_tags=False))
         sys.exit()
-    elif args.tasks == "list_with_num":
-        log_message = (
-            "\n" + "=" * 70 + "\n" + "\n\tYou are trying to check all the numbers in each task." + "\n\tThis action will download the complete dataset." + "\n\tIf the results are not clear initially, call this again." + "\n\n" + "=" * 70
-        )
-        eval_logger.info(log_message)
-        for task_name in sorted(task_manager.list_all_tasks()):
-            try:
-                task_dict = get_task_dict([task_name], model_name="llava")
-                task_obj = task_dict[task_name]
-                if type(task_obj) == tuple:
-                    group, task_obj = task_obj
-                    if task_obj is None:
-                        continue
-                eval_logger.info(f"\nTask : {task_obj.config.task}\n - #num : {len(task_obj.test_docs()) if task_obj.has_test_docs() else len(task_obj.validation_docs())}")
-            except Exception as e:
-                eval_logger.debug(f"\nTask : {task_name} fail to load \n Exception : \n {e}")
-        sys.exit()
     else:
         if os.path.isdir(args.tasks):
             import glob

lmms_eval/api/model.py

Lines changed: 2 additions & 0 deletions
@@ -15,6 +15,8 @@
 
 
 class lmms(abc.ABC):
+    is_simple: bool = True
+
     def __init__(self) -> None:
         """Defines the interface that should be implemented by all lmms subclasses.
         lmmss are assumed to take image-text as input and yield strings as output
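The flag defaults to True, so every existing model keeps the simple request path; a backend that wants the new chat/messages path just overrides it. A minimal sketch of opting in (MyChatModel is a hypothetical name, not part of this commit):

    from lmms_eval.api.model import lmms

    class MyChatModel(lmms):
        # Opting in to the chat/messages request path; the evaluator reads
        # this flag, as the simple_evaluate change below shows.
        is_simple = False

        def generate_until(self, requests):
            # Chat-path requests carry a doc_to_messages callable rather than
            # a pre-rendered prompt string; see chat/llava_hf.py below.
            ...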

lmms_eval/api/task.py

Lines changed: 40 additions & 0 deletions
@@ -32,6 +32,7 @@
 from datasets import Audio, DownloadConfig, Image, Sequence
 from huggingface_hub import snapshot_download
 from loguru import logger as eval_logger
+from PIL import Image as PIL_Image
 from PIL import ImageFile
 from tenacity import retry, stop_after_attempt, stop_after_delay, wait_fixed
 from tqdm import tqdm
@@ -91,6 +92,7 @@ class TaskConfig(dict):
     doc_to_text: Union[Callable, str] = None
     doc_to_target: Union[Callable, str] = None
     doc_to_choice: Union[Callable, str, dict, list] = None
+    doc_to_messages: Callable = None
     process_results: Union[Callable, str] = None
     use_prompt: str = None
     description: str = ""
@@ -1634,3 +1636,41 @@ def task_name(self) -> Any:
 
     def __repr__(self):
         return f"ConfigurableTask(task_name={getattr(self.config, 'task', None)}," f"output_type={self.OUTPUT_TYPE}," f"num_fewshot={getattr(self.config, 'num_fewshot', None)}," f"num_samples={len(self.eval_docs)})"
+
+
+class ConfigurableMessagesTask(ConfigurableTask):
+    def doc_to_messages(self, doc: dict) -> Union[int, str, list]:
+        if callable(self.config.doc_to_messages):
+            return (
+                self.config.doc_to_messages(doc, self.lmms_eval_specific_kwargs)
+                if self.lmms_eval_specific_kwargs is not None and len(inspect.signature(self.config.doc_to_messages).parameters) == 2
+                else self.config.doc_to_messages(
+                    doc,
+                )
+            )
+        elif self.config.doc_to_messages is None and self.config.doc_to_visual is not None and self.config.doc_to_text is not None:
+            # An auto doc-to-messages function
+            def auto_doc_to_messages(doc):
+                visuals = self.doc_to_visual(doc)
+                text = self.doc_to_text(doc)
+                messages = [{"role": "user", "content": []}]
+                content = []
+                for visual in visuals:
+                    if isinstance(visual, PIL_Image.Image):
+                        content.append({"type": "image", "url": visual})
+                content.append({"type": "text", "text": text})
+                messages[0]["content"] = content
+                return messages
+
+            return auto_doc_to_messages(doc)
+        else:
+            # eval_logger.warning("Note that doc_to_visual was called but not set in config. Please check if this is a text-only task.")
+            return self.config.doc_to_messages
+
+    def construct_requests(self, doc_id: int, ctx: str, **kwargs) -> Union[List[Instance], Instance]:
+        split = kwargs.get("metadata").get("split")
+        # kwargs.pop("split")
+        assert self.OUTPUT_TYPE == "generate_until", "Currently messages is used for generation only"
+
+        arguments = (self.doc_to_messages, copy.deepcopy(self.config.generation_kwargs), doc_id, self.config.task, split)
+        return Instance(request_type=self.OUTPUT_TYPE, arguments=arguments, idx=0, **kwargs)
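For a task that defines only doc_to_visual and doc_to_text, the auto converter wraps everything into a single OpenAI-style user turn. A sketch of the structure it returns for a one-image doc (the image and question are stand-ins):

    from PIL import Image

    # Suppose doc_to_visual(doc) -> [img] and doc_to_text(doc) -> a question.
    img = Image.new("RGB", (336, 336))

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "url": img},  # one entry per PIL image
                {"type": "text", "text": "What is in the image?"},  # prompt text appended last
            ],
        }
    ]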

lmms_eval/evaluator.py

Lines changed: 3 additions & 4 deletions
@@ -172,8 +172,6 @@ def simple_evaluate(
     if task_manager is None:
         task_manager = TaskManager(verbosity, model_name=model)
 
-    task_dict = get_task_dict(tasks, task_manager)
-
     if isinstance(model, str):
         if model_args is None:
             model_args = ""
@@ -187,6 +185,8 @@ def simple_evaluate(
         )
     elif isinstance(model, lmms_eval.api.model.lmms):
         lm = model
+    task_type = "simple" if lm.is_simple else "chat"
+    task_dict = get_task_dict(tasks, task_manager, task_type)
 
     # helper function to recursively apply config overrides to leaf subtasks, skipping their constituent groups.
     # (setting of num_fewshot ; bypassing metric calculation ; setting fewshot seed)
@@ -551,8 +551,7 @@ def evaluate(
                         ensure_ascii=False,
                     )
                 ),
-                "prompt_hash": hash_string(requests[0].arguments[0]),
-                "target_hash": hash_string(str(target)),
+                # Removing prompt hash and target hash here
             }
             example.update(metrics)
            task_output.logged_samples.append(example)
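Downstream, get_task_dict now receives the task_type and, presumably, instantiates ConfigurableMessagesTask rather than ConfigurableTask for the chat path. A hedged sketch of that selection (build_task is a hypothetical helper; the real logic lives inside get_task_dict, which this commit does not show):

    def build_task(task_config, task_type: str = "simple"):
        # Assumption: chat-capable models get the messages-based task class,
        # everything else keeps the plain ConfigurableTask.
        if task_type == "chat":
            return ConfigurableMessagesTask(config=task_config)
        return ConfigurableTask(config=task_config)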

lmms_eval/models/__init__.py

Lines changed: 16 additions & 4 deletions
@@ -1,6 +1,7 @@
 import importlib
 import os
 import sys
+from typing import Literal
 
 import hf_transfer
 from loguru import logger
@@ -10,7 +11,8 @@
 logger.remove()
 logger.add(sys.stdout, level="WARNING")
 
-AVAILABLE_MODELS = {
+
+AVAILABLE_SIMPLE_MODELS = {
     "aero": "Aero",
     "plm": "PerceptionLM",
     "aria": "Aria",
@@ -75,14 +77,23 @@
     "vora": "VoRA",
 }
 
+AVAILABLE_CHAT_TEMPLATE_MODELS = {"llava_hf": "LlavaHf"}
+
 
 def get_model(model_name):
-    if model_name not in AVAILABLE_MODELS:
+    if model_name not in AVAILABLE_SIMPLE_MODELS and model_name not in AVAILABLE_CHAT_TEMPLATE_MODELS:
         raise ValueError(f"Model {model_name} not found in available models.")
 
+    if model_name in AVAILABLE_CHAT_TEMPLATE_MODELS:
+        model_type = "chat"
+        AVAILABLE_MODELS = AVAILABLE_CHAT_TEMPLATE_MODELS
+    else:
+        model_type = "simple"
+        AVAILABLE_MODELS = AVAILABLE_SIMPLE_MODELS
+
     model_class = AVAILABLE_MODELS[model_name]
     if "." not in model_class:
-        model_class = f"lmms_eval.models.{model_name}.{model_class}"
+        model_class = f"lmms_eval.models.{model_type}.{model_name}.{model_class}"
 
     try:
         model_module, model_class = model_class.rsplit(".", 1)
@@ -97,5 +108,6 @@ def get_model(model_name):
 # Allow specifying other packages to import models from
 for plugin in os.environ["LMMS_EVAL_PLUGINS"].split(","):
     m = importlib.import_module(f"{plugin}.models")
+    # For plugin users, this will be replaced by the chat template model later
     for model_name, model_class in getattr(m, "AVAILABLE_MODELS").items():
-        AVAILABLE_MODELS[model_name] = f"{plugin}.models.{model_name}.{model_class}"
+        AVAILABLE_SIMPLE_MODELS[model_name] = f"{plugin}.models.{model_name}.{model_class}"
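The registry name alone now decides both the import path and the request style. A usage sketch consistent with the tables above:

    from lmms_eval.models import get_model

    # "llava_hf" is listed in AVAILABLE_CHAT_TEMPLATE_MODELS, so this resolves
    # to lmms_eval.models.chat.llava_hf.LlavaHf (the chat-path model below).
    ChatModel = get_model("llava_hf")

    # "aria" appears only in AVAILABLE_SIMPLE_MODELS, so it resolves to
    # lmms_eval.models.simple.aria.Aria.
    SimpleModel = get_model("aria")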

lmms_eval/models/chat/llava_hf.py

Lines changed: 122 additions & 0 deletions
@@ -0,0 +1,122 @@
+import warnings
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import PIL
+import torch
+from accelerate import Accelerator, DistributedType
+from accelerate.state import AcceleratorState
+from decord import VideoReader, cpu
+from tqdm import tqdm
+from transformers import (
+    AutoConfig,
+    AutoProcessor,
+    LlavaForConditionalGeneration,
+    LlavaNextForConditionalGeneration,
+)
+
+from lmms_eval import utils
+from lmms_eval.api.instance import Instance
+from lmms_eval.api.model import lmms
+from lmms_eval.api.registry import register_model
+from lmms_eval.protocol import ChatMessages
+
+warnings.filterwarnings("ignore")
+
+from loguru import logger as eval_logger
+
+from lmms_eval.api.registry import register_model
+from lmms_eval.models.simple.llava_hf import LlavaHf as LlavaHfSimple
+
+DEFAULT_IMAGE_TOKEN = "<image>"
+DEFAULT_VIDEO_TOKEN = "<video>"
+
+# Default chat template for llava-hf/llava-1.5 models: https://huggingface.co/collections/llava-hf/llava-15-65f762d5b6941db5c2ba07e0
+VICUNA_CHAT_TEMPLATE = "{% for message in messages %}{% if loop.index0 == 0 %}A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: {{ message['content'] }} {% elif message['role'] == 'user' %}USER: {{ message['content'] }} {% else %} ASSISTANT: {{ message['content'] }}{{ eos_token }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'ASSISTANT:' }}{% endif %}"
+
+
+@register_model("llava_hf_chat")
+class LlavaHf(LlavaHfSimple):
+    is_simple = False
+
+    def generate_until(self, requests: List[Instance]) -> List[str]:
+        res = []
+
+        # A dummy collate here to sort by doc id
+        def _collate(x):
+            return x[2], x[2]
+
+        # We group requests by their generation_kwargs, so that we don't try
+        # to execute e.g. greedy sampling and temp=0.8 sampling
+        # in the same batch.
+        re_ords = utils.Collator([reg.args for reg in requests], _collate, grouping=True)
+        chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None)
+        num_iters = len(requests) // self.batch_size if len(requests) % self.batch_size == 0 else len(requests) // self.batch_size + 1
+        pbar = tqdm(total=num_iters, disable=(self.rank != 0), desc="Model Responding")
+        for chunk in chunks:
+            doc_to_messages, all_gen_kwargs, doc_id, task, split = zip(*chunk)
+            task = task[0]
+            split = split[0]
+            chat_messages = [doc_to_messages[0](self.task_dict[task][split][ids]) for ids in doc_id]
+            chat_messages: List[ChatMessages] = [ChatMessages(**{"messages": message}) for message in chat_messages]
+            visuals = []
+            videos = []
+            for messages in chat_messages:
+                visual, video, _ = messages.extract_media()
+                visuals.append(visual)
+                videos.append(video)
+            visuals = self.flatten(visuals)
+            videos = self.flatten(videos)
+            assert self.batch_size_per_gpu == 1, "Do not support batch_size_per_gpu > 1 for now"
+
+            # Apply chat template
+            messages = chat_messages[0].model_dump()["messages"]
+            text = self._image_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+            if self.accelerator.is_main_process and doc_id[0] % 100 == 0:
+                eval_logger.debug(f"Prompt for doc ID {doc_id[0]}:\n\n{text}\n")
+
+            if len(videos) == 0:
+                videos = None
+            inputs = self._image_processor(images=visuals, videos=videos, text=text, return_tensors="pt").to(self._device, self.model.dtype)
+
+            # We assume all gen kwargs in the batch are the same;
+            # this is safe to assume because the `grouper` object ensures it.
+            gen_kwargs = all_gen_kwargs[0]
+
+            gen_kwargs["image_sizes"] = [visuals[idx].size for idx in range(len(visuals))]
+            if "max_new_tokens" not in gen_kwargs:
+                gen_kwargs["max_new_tokens"] = 1024
+            if "temperature" not in gen_kwargs:
+                gen_kwargs["temperature"] = 0
+            if "top_p" not in gen_kwargs:
+                gen_kwargs["top_p"] = None
+            if "num_beams" not in gen_kwargs:
+                gen_kwargs["num_beams"] = 1
+            try:
+                cont = self.model.generate(
+                    **inputs,
+                    do_sample=True if gen_kwargs["temperature"] > 0 else False,
+                    temperature=gen_kwargs["temperature"],
+                    top_p=gen_kwargs["top_p"],
+                    num_beams=gen_kwargs["num_beams"],
+                    max_new_tokens=gen_kwargs["max_new_tokens"],
+                    use_cache=self.use_cache,
+                    pad_token_id=self.eot_token_id,
+                    eos_token_id=self.eot_token_id,
+                )
+                cont = cont[:, inputs["input_ids"].shape[-1] :]
+            except Exception as e:
+                eval_logger.error(f"Error {e} in generating")
+                cont = ""
+            text_outputs = self.tokenizer.batch_decode(cont, skip_special_tokens=True)[0]
+            if self.accelerator.is_main_process and doc_id[0] % 100 == 0:
+                eval_logger.debug(f"Generated text for doc ID {doc_id[0]}:\n\n{text_outputs}\n")
+
+            res.append(text_outputs)
+            self.cache_hook.add_partial("generate_until", (text, gen_kwargs), text_outputs)
+            pbar.update(1)
+        # Reorder this group of results back to original unsorted form
+        res = re_ords.get_original(res)
+
+        pbar.close()
+        return res
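Each chat-path request arrives carrying the five-tuple built by ConfigurableMessagesTask.construct_requests, so the model rebuilds the messages from the doc at generation time. The unpacking, in outline (instance is a stand-in name for one request):

    # args = (doc_to_messages, gen_kwargs, doc_id, task, split)
    doc_to_messages, gen_kwargs, doc_id, task, split = instance.args

    doc = self.task_dict[task][split][doc_id]      # look the doc back up
    raw = doc_to_messages(doc)                     # list of role/content dicts
    chat = ChatMessages(messages=raw)              # validate against the protocol
    images, videos, audios = chat.extract_media()  # media for the processor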

lmms_eval/protocol.py

Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
+from typing import Any, Dict, List, Literal, Union
+
+from PIL import Image
+from pydantic import BaseModel
+
+
+class ChatTextContent(BaseModel):
+    type: Literal["text"] = "text"
+    text: str
+
+
+class ChatImageContent(BaseModel):
+    type: Literal["image"] = "image"
+    url: Any
+
+    def model_dump(self, **kwargs):
+        content = super().model_dump(**kwargs)
+        # Some models may need this placeholder for hf_chat_template
+        content["image_url"] = "placeholder"
+        return content
+
+
+ChatContent = Union[ChatTextContent, ChatImageContent]
+
+
+class ChatMessage(BaseModel):
+    role: Literal["user", "system", "assistant"]
+    content: List[ChatContent]
+
+
+class ChatMessages(BaseModel):
+    messages: List[ChatMessage]
+
+    def extract_media(self):
+        images = []
+        videos = []
+        audios = []
+
+        for message in self.messages:
+            for content in message.content:
+                if content.type == "image":
+                    images.append(content.url)
+
+        return images, videos, audios
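A quick usage sketch of the protocol types (the image is a stand-in; note that extract_media currently collects images only, so videos and audios come back empty):

    from PIL import Image
    from lmms_eval.protocol import ChatMessages

    img = Image.new("RGB", (64, 64))

    chat = ChatMessages(
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "image", "url": img},
                    {"type": "text", "text": "Describe the image."},
                ],
            }
        ]
    )

    images, videos, audios = chat.extract_media()
    assert images == [img] and videos == [] and audios == []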
