
generator: vision nims #959


Merged: 9 commits, Oct 29, 2024
garak/attempt.py (1 change: 0 additions & 1 deletion)
@@ -174,7 +174,6 @@ def __setattr__(self, name: str, value: Any) -> None:
         if name == "prompt":
             if value is None:
                 raise TypeError("'None' prompts are not valid")
-            assert isinstance(value, str)
             self._add_first_turn("user", value)

         elif name == "outputs":
garak/generators/base.py (8 changes: 3 additions & 5 deletions)
@@ -74,12 +74,10 @@ def _pre_generate_hook(self):

     @staticmethod
     def _verify_model_result(result: List[Union[str, None]]):
-        assert isinstance(
-            result, list
-        ), "_call_model must return a list"
+        assert isinstance(result, list), "_call_model must return a list"
         assert (
-            len(result) == 1
-        ), "_call_model must return a list of one item when invoked as _call_model(prompt, 1)"
+            len(result) == 1
+        ), f"_call_model must return a list of one item when invoked as _call_model(prompt, 1), got {result}"
         assert (
             isinstance(result[0], str) or result[0] is None
         ), "_call_model's item must be a string or None"
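For orientation, a minimal sketch (not part of this diff) of a `_call_model` override that satisfies the contract these assertions enforce: called as `_call_model(prompt, 1)`, it must return a one-item list whose element is a `str` or `None`. The `EchoGenerator` class and its family name are hypothetical.

```python
from typing import List, Union

from garak.generators.base import Generator


class EchoGenerator(Generator):
    """Hypothetical generator, used only to illustrate the _call_model contract."""

    generator_family_name = "example"

    def _call_model(
        self, prompt: str, generations_this_call: int = 1
    ) -> List[Union[str, None]]:
        # One entry per requested generation; each entry is a str or None.
        return [prompt] * generations_this_call
```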
garak/generators/huggingface.py (6 changes: 5 additions & 1 deletion)
@@ -520,7 +520,10 @@ def _call_model(


 class LLaVA(Generator, HFCompatible):
-    """Get LLaVA ([ text + image ] -> text) generations"""
+    """Get LLaVA ([ text + image ] -> text) generations
+
+    NB. This should be used with strict modality matching - generate() doesn't
+    support text-only prompts."""

     DEFAULT_PARAMS = Generator.DEFAULT_PARAMS | {
         "max_tokens": 4000,
@@ -570,6 +573,7 @@ def __init__(self, name="", config_root=_config):
     def generate(
         self, prompt: str, generations_this_call: int = 1
     ) -> List[Union[str, None]]:
+
         text_prompt = prompt["text"]
         try:
             image_prompt = Image.open(prompt["image"])
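The docstring note exists because `generate()` immediately indexes the prompt as a dict (`prompt["text"]`, `prompt["image"]`), so a plain string prompt fails. A hedged usage sketch; the image path and chat template below are illustrative and depend on the LLaVA variant being loaded:

```python
# Sketch only: LLaVA prompts are dicts, not plain strings.
llava_prompt = {
    "text": "USER: <image>\nDescribe this picture.\nASSISTANT:",  # template varies by model
    "image": "example.png",  # placeholder path, not a file shipped with garak
}
# outputs = llava.generate(llava_prompt)  # -> List[Union[str, None]]
```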
garak/generators/nim.py (53 changes: 53 additions & 0 deletions)
@@ -67,6 +67,9 @@ def _clear_client(self):
         self.generator = None
         self.client = None

+    def _prepare_prompt(self, prompt):
+        return prompt
+
     def _call_model(
         self, prompt: str | List[dict], generations_this_call: int = 1
     ) -> List[Union[str, None]]:
@@ -80,8 +83,17 @@ def _call_model(
         if self.vary_temp_each_call:
             self.temperature = random.random()

+        prompt = self._prepare_prompt(prompt)
+        if prompt is None:
+            # if we didn't get a valid prompt, don't process it, and send the NoneType(s) downstream
+            return [None] * generations_this_call
+
         try:
             result = super()._call_model(prompt, generations_this_call)
+        except openai.UnprocessableEntityError as uee:
+            msg = "Model call didn't match endpoint expectations, see log"
+            logging.critical(msg, exc_info=uee)
+            raise GarakException(f"🛑 {msg}") from uee
         # except openai.NotFoundError as oe:
         except Exception as oe:
             msg = "NIM endpoint not found. Is the model name spelled correctly?"
@@ -125,4 +137,45 @@ def _load_client(self):
         self.generator = self.client.completions


+class Vision(NVOpenAIChat):
+    """Wrapper for text+image to text NIMs. Expects NIM_API_KEY environment variable.
+
+    Following generators.huggingface.LLaVa, expects prompts to be a dict with keys
+    "text" and "image"; text holds the text prompt, image holds a path to the image."""
+
+    DEFAULT_PARAMS = NVOpenAIChat.DEFAULT_PARAMS | {
+        "suppressed_params": {"n", "frequency_penalty", "presence_penalty", "stop"},
+        "max_image_len": 180_000,
+    }
+
+    modality = {"in": {"text", "image"}, "out": {"text"}}
+
+    def _prepare_prompt(self, prompt):
+        import base64
+
+        if isinstance(prompt, str):
+            prompt = {"text": prompt, "image": None}
+
+        text = prompt["text"]
+        image_filename = prompt["image"]
+        if image_filename is not None:
+            with open(image_filename, "rb") as f:
+                image_b64 = base64.b64encode(f.read()).decode()
+
+            if len(image_b64) > self.max_image_len:
+                logging.error(
+                    "Image %s exceeds length limit. To upload larger images, use the assets API (not yet supported)",
+                    image_filename,
+                )
+                return None
+
+            image_extension = prompt["image"].split(".")[-1].lower()
+            if image_extension == "jpg":  # image/jpg is not a valid mimetype
+                image_extension = "jpeg"
+            text = (
+                text + f' <img src="data:image/{image_extension};base64,{image_b64}" />'
+            )
+        return text
+
+
 DEFAULT_CLASS = "NVOpenAIChat"
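To trace the flow, a short sketch of what `Vision._prepare_prompt` produces; the path and prompt text are placeholders. The image is read, base64-encoded, and appended to the text as an inline `<img>` tag, which the inherited chat-completions call then sends as the user message.

```python
# Illustrative only; "weather_map.jpg" is a placeholder path.
vision_prompt = {"text": "What does this chart show?", "image": "weather_map.jpg"}

# After _prepare_prompt, the value handed to the OpenAI-compatible chat call is a
# single string, roughly:
#   'What does this chart show? <img src="data:image/jpeg;base64,/9j/4AAQ..." />'
# A plain string prompt is wrapped as {"text": prompt, "image": None} and passed
# through unchanged; an oversized image makes _prepare_prompt return None, so
# _call_model short-circuits and returns [None] * generations_this_call.
```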
garak/harnesses/base.py (23 changes: 20 additions & 3 deletions)
@@ -29,7 +29,9 @@ class Harness(Configurable):

     active = True

-    DEFAULT_PARAMS = {}
+    DEFAULT_PARAMS = {
+        "strict_modality_match": False,
+    }
Comment on lines +32 to +34

jmartin-tech (Collaborator), Oct 28, 2024:
This seems reasonable to keep in the harness for now, though I see some possible contention, as noted in the questions in the PR description. I suspect strict modality matching may end up being the probe's responsibility to define: the technique a probe employs will set whether a strict match is required, since generators may often support several modalities.

Author reply:
Interested to hear more.


     def __init__(self, config_root=_config):
         self._load_config(config_root)
@@ -96,8 +98,12 @@ def run(self, model, probes, detectors, evaluator, announce_probe=True) -> None:
             logging.debug("harness: probe start for %s", probe.probename)
             if not probe:
                 continue
-            # TODO: refactor this to allow `compatible` probes instead of direct match
-            if probe.modality["in"] != model.modality["in"]:
+
+            modality_match = _modality_match(
+                probe.modality["in"], model.modality["in"], self.strict_modality_match
+            )
+
+            if not modality_match:
                 logging.warning(
                     "probe skipped due to modality mismatch: %s - model expects %s",
                     probe.probename,
@@ -136,3 +142,14 @@ def run(self, model, probes, detectors, evaluator, announce_probe=True) -> None:
             evaluator.evaluate(attempt_results)

         logging.debug("harness: probe list iteration completed")
+
+
+def _modality_match(probe_modality, generator_modality, strict):
+    if strict:
+        # must be perfect match
+        return probe_modality == generator_modality
+    else:
+        # everything probe wants must be accepted by model
+        return set(probe_modality).intersection(generator_modality) == set(
+            probe_modality
+        )
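A worked example of the two modes (the new test below exercises the same matrix): with the default non-strict setting, a text-only probe runs against a text+image generator because the generator accepts everything the probe needs; with `strict_modality_match` enabled, the sets must be identical.

```python
from garak.harnesses.base import _modality_match

probe_in = {"text"}               # what the probe produces
generator_in = {"text", "image"}  # what the generator accepts

_modality_match(probe_in, generator_in, False)  # True: probe needs are a subset
_modality_match(probe_in, generator_in, True)   # False: the sets are not identical
```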
tests/harnesses/test_harnesses.py (22 changes: 22 additions & 0 deletions)
@@ -6,6 +6,8 @@

 from garak import _plugins

+import garak.harnesses.base
+
 HARNESSES = [
     classname for (classname, active) in _plugins.enumerate_plugins("harnesses")
 ]
@@ -25,3 +27,23 @@ def test_buff_structure(classname):
         if k not in c._supported_params:
             unsupported_defaults.append(k)
     assert unsupported_defaults == []
+
+
+def test_harness_modality_match():
+    t = {"text"}
+    ti = {"text", "image"}
+    tv = {"text", "vision"}
+    tvi = {"text", "vision", "image"}
+
+    # probe, generator
+    assert garak.harnesses.base._modality_match(t, t, True) is True
+    assert garak.harnesses.base._modality_match(ti, ti, True) is True
+    assert garak.harnesses.base._modality_match(t, tv, True) is False
+    assert garak.harnesses.base._modality_match(ti, t, True) is False
+
+    # when strict is false, generator must support all probe modalities, but can also support more
+    assert garak.harnesses.base._modality_match(t, t, False) is True
+    assert garak.harnesses.base._modality_match(ti, t, False) is False
+    assert garak.harnesses.base._modality_match(t, tvi, False) is True
+    assert garak.harnesses.base._modality_match(ti, tvi, False) is True
+    assert garak.harnesses.base._modality_match(t, ti, False) is True