Skip to content

Commit e49e49f

Browse files
committed
type check generators for Turn patterns
1 parent 1e84ad9 commit e49e49f

File tree

5 files changed: +59 −15 lines changed

garak/generators/huggingface.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,9 @@ def __init__(self, name="", config_root=_config):
242242
),
243243
max_value=125,
244244
)
245-
def _call_model(self, prompt: Turn, generations_this_call: int = 1) -> List[Turn]:
245+
def _call_model(
246+
self, prompt: Turn, generations_this_call: int = 1
247+
) -> List[Turn | None]:
246248
import json
247249
import requests
248250

@@ -350,7 +352,9 @@ def __init__(self, name="", config_root=_config):
350352
),
351353
max_value=125,
352354
)
353-
def _call_model(self, prompt: Turn, generations_this_call: int = 1) -> List[Turn]:
355+
def _call_model(
356+
self, prompt: Turn, generations_this_call: int = 1
357+
) -> List[Turn | None]:
354358
import requests
355359

356360
payload = {
@@ -440,7 +444,9 @@ def _clear_client(self):
440444
self.tokenizer = None
441445
self.generation_config = None
442446

443-
def _call_model(self, prompt: Turn, generations_this_call: int = 1) -> List[Turn]:
447+
def _call_model(
448+
self, prompt: Turn, generations_this_call: int = 1
449+
) -> List[Turn | None]:
444450
self._load_client()
445451
self.generation_config.max_new_tokens = self.max_tokens
446452
self.generation_config.do_sample = self.hf_args["do_sample"]

garak/generators/replicate.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,12 +79,12 @@ class InferenceEndpoint(ReplicateGenerator):
7979
backoff.fibo, replicate.exceptions.ReplicateError, max_value=70
8080
)
8181
def _call_model(
82-
self, prompt, generations_this_call: int = 1
83-
) -> List[Union[str, None]]:
82+
self, prompt: Turn, generations_this_call: int = 1
83+
) -> List[Union[Turn, None]]:
8484
deployment = self.replicate.deployments.get(self.name)
8585
prediction = deployment.predictions.create(
8686
input={
87-
"prompt": prompt,
87+
"prompt": prompt.text,
8888
"max_length": self.max_tokens,
8989
"temperature": self.temperature,
9090
"top_p": self.top_p,
@@ -98,7 +98,7 @@ def _call_model(
9898
raise IOError(
9999
"Replicate endpoint didn't generate a response. Make sure the endpoint is active."
100100
) from exc
101-
return [response]
101+
return [Turn(r) for r in response]
102102

103103

104104
DEFAULT_CLASS = "ReplicateGenerator"

garak/generators/test.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@ class Blank(Generator):
1818
generator_family_name = "Test"
1919
name = "Blank"
2020

21-
def _call_model(self, prompt: Turn, generations_this_call: int = 1) -> List[Turn]:
21+
def _call_model(
22+
self, prompt: Turn, generations_this_call: int = 1
23+
) -> List[Turn | None]:
2224
return [Turn("")] * generations_this_call
2325

2426

@@ -29,7 +31,9 @@ class Repeat(Generator):
2931
generator_family_name = "Test"
3032
name = "Repeat"
3133

32-
def _call_model(self, prompt: Turn, generations_this_call: int = 1) -> List[Turn]:
34+
def _call_model(
35+
self, prompt: Turn, generations_this_call: int = 1
36+
) -> List[Turn | None]:
3337
return [prompt] * generations_this_call
3438

3539

@@ -41,7 +45,9 @@ class Single(Generator):
4145
name = "Single"
4246
test_generation_string = "ELIM"
4347

44-
def _call_model(self, prompt: Turn, generations_this_call: int = 1) -> List[Turn]:
48+
def _call_model(
49+
self, prompt: Turn, generations_this_call: int = 1
50+
) -> List[Turn | None]:
4551
if generations_this_call == 1:
4652
return [Turn(self.test_generation_string)]
4753
else:
@@ -71,7 +77,9 @@ class BlankVision(Generator):
7177
name = "BlankVision"
7278
modality = {"in": {"text", "image"}, "out": {"text"}}
7379

74-
def _call_model(self, prompt: Turn, generations_this_call: int = 1) -> List[Turn]:
80+
def _call_model(
81+
self, prompt: Turn, generations_this_call: int = 1
82+
) -> List[Turn | None]:
7583
return [Turn("")] * generations_this_call
7684

7785

garak/generators/watsonx.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -126,14 +126,14 @@ def _validate_env_var(self):
126126
return super()._validate_env_var()
127127

128128
def _call_model(
129-
self, prompt: str, generations_this_call: int = 1
130-
) -> List[Union[str, None]]:
129+
self, prompt: Turn, generations_this_call: int = 1
130+
) -> List[Union[Turn, None]]:
131131
if not self.bearer_token:
132132
self._set_bearer_token()
133133

134134
# Check if message is empty. If it is, append null byte.
135-
if not prompt:
136-
prompt = "\x00"
135+
if not prompt or not prompt.text:
136+
prompt = Turn("\x00")
137137
print(
138138
"WARNING: Empty prompt was found. Null byte character appended to prevent API failure."
139139
)

tests/generators/test_generators.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
import pytest
77
import random
88

9+
from typing import List, Union
10+
911
from garak import _plugins
1012
from garak import _config
1113
from garak.attempt import Turn
@@ -222,3 +224,31 @@ def test_instantiate_generators(classname):
222224
m = importlib.import_module("garak." + ".".join(classname.split(".")[:-1]))
223225
g = getattr(m, classname.split(".")[-1])(config_root=config_root)
224226
assert isinstance(g, Generator)
227+
228+
229+
NON_CONVERSATION_GENERATORS = [
230+
classname
231+
for classname in GENERATORS
232+
if not ("openai" in classname or "groq" in classname or "azure" in classname)
233+
]
234+
235+
236+
@pytest.mark.parametrize("classname", NON_CONVERSATION_GENERATORS)
237+
def test_generator_signature(classname):
238+
_, namespace, klass = classname.split(".")
239+
m = importlib.import_module(f"garak.generators.{namespace}")
240+
g = getattr(m, klass)
241+
generate_signature = inspect.signature(g.generate)
242+
assert (
243+
generate_signature.parameters.get("prompt").annotation == Turn
244+
), "generate should take a Turn and return list of Turns or Nones"
245+
assert (
246+
generate_signature.return_annotation == List[Union[None, Turn]]
247+
), "generate should take a Turn and return list of Turns or Nones"
248+
_call_model_signature = inspect.signature(g._call_model)
249+
assert (
250+
_call_model_signature.parameters.get("prompt").annotation == Turn
251+
), "_call_model should take a Turn and return list of Turns or Nones"
252+
assert (
253+
_call_model_signature.return_annotation == List[Union[None, Turn]]
254+
), "_call_model should take a Turn and return list of Turns or Nones"

0 commit comments

Comments (0)