feature: deferred loading and requirement pruning #1199

Draft · wants to merge 39 commits into base: main

Changes from 15 commits

Commits:
3000d4c  draft postponed import pattern for cohere generator (leondz, May 2, 2025)
757e0f3  move extra dependency requirements into classdefs, mediate requiremen… (leondz, May 5, 2025)
9310d0a  actually do the plugin dep load (leondz, May 5, 2025)
dac569e  migrate generators to 'extra dependencies' pattern (leondz, May 5, 2025)
35e93fc  prune dupe lazyload (leondz, May 7, 2025)
bf7f36b  extra_dependency_names in all plugins (leondz, May 7, 2025)
6a39b0c  active must be False for Probes using extra modules (leondz, May 7, 2025)
56c6182  make PIL optional in generators.huggingface.LLaVA (leondz, May 7, 2025)
3657e04  move optional load fail to ModuleNotFoundError (leondz, May 7, 2025)
865d604  add _load/_clear_deps() into base generator and _load/_clear client (leondz, May 7, 2025)
d61957d  put the MNFE where it belongs (leondz, May 7, 2025)
8a7051e  backoff exception placeholder must inherit base exception (leondz, May 7, 2025)
60775f6  test for reqs presence in pyproject.toml, requirements.txt (leondz, May 7, 2025)
31e98d4  handle hyphen in pypi pkg names (leondz, May 7, 2025)
75babb7  rm optional plugin deps (leondz, May 7, 2025)
83f551a  skip generator tests if optional deps absent (leondz, May 8, 2025)
dd51196  support sub-package deps (leondz, May 8, 2025)
b33a46c  scope optimum to nvidia (leondz, May 8, 2025)
de5b3f1  move import function to _load_deps (leondz, May 8, 2025)
19c31fe  rm import handling in langchain (leondz, May 8, 2025)
54fabc5  amend optimum to be nvidia flavour (leondz, May 8, 2025)
ffac714  dry - use garak._plugins.PLUGIN_TYPES as canonical def of 1st class p… (leondz, May 8, 2025)
97c8160  unify backoff exception pattern mediated via garak GeneratorBackoffEx… (leondz, May 9, 2025)
1d4e69c  skip instantiation when modules not present (leondz, May 9, 2025)
6164bc5  catch straggling backoff exception wrappings (leondz, May 9, 2025)
85fb7c3  Merge branch 'main' into update/optional_imports (leondz, May 9, 2025)
0402116  use isinstance for exception matching (leondz, May 9, 2025)
e287fe9  don't backoff on 404 (leondz, May 9, 2025)
6339648  merge in our good pal main (leondz, May 16, 2025)
76b1774  switch to pyproject; get tests deps if testing (leondz, May 16, 2025)
ca133e4  add [dev] target (leondz, May 16, 2025)
8e8a5b9  add required jsonschema that was previously implicit from now-optiona… (leondz, May 16, 2025)
aa7500a  specify versions; move to secure versions cf. #1207 (leondz, May 16, 2025)
4f2e5ef  skip internal config mappings for req consistency testing (leondz, May 16, 2025)
69cfef2  skip test option for non-test workflow (leondz, May 16, 2025)
a1da5ed  skip ollama tests if no module (leondz, May 16, 2025)
3a8605d  rm spurious dep check (leondz, May 16, 2025)
d2d17ad  straggling spurious check (leondz, May 16, 2025)
13974b8  Merge branch 'main' into update/optional_imports (leondz, May 28, 2025)
30 changes: 30 additions & 0 deletions garak/_plugins.py
@@ -400,6 +400,18 @@ def load_plugin(path, break_on_fail=True, config_root=_config) -> object:
) from ve
else:
return False

full_plugin_name = ".".join((category, module_name, plugin_class_name))

# check cache for optional imports
extra_dependency_names = PluginCache.instance()[category][full_plugin_name][
"extra_dependency_names"
]
if len(extra_dependency_names) > 0:
for dependency_module_name in extra_dependency_names:
if importlib.util.find_spec(dependency_module_name) is None:
_import_failed(dependency_module_name, full_plugin_name)

module_path = f"garak.{category}.{module_name}"
try:
mod = importlib.import_module(module_path)
@@ -426,6 +438,7 @@ def load_plugin(path, break_on_fail=True, config_root=_config) -> object:
if plugin_instance is None:
plugin_instance = klass(config_root=config_root)
PluginProvider.storeInstance(plugin_instance, config_root)

except Exception as e:
logging.warning(
"Exception instantiating %s.%s: %s",
@@ -440,3 +453,20 @@ def load_plugin(path, break_on_fail=True, config_root=_config) -> object:
return False

return plugin_instance


def load_optional_module(module_name: str):
try:
m = importlib.import_module(module_name)
except ModuleNotFoundError:
requesting_module = Path(inspect.stack()[1].filename).name.replace(".py", "")
_import_failed(module_name, requesting_module)
return m


def _import_failed(import_module: str, calling_module: str):
msg = f"⛔ Plugin '{calling_module}' requires Python module '{import_module}' but this isn't installed/available."
hint = f"💡 Try 'pip install {import_module}' to get it."
logging.critical(msg)
print(msg + "\n" + hint)
raise ModuleNotFoundError(msg)
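
A minimal sketch of how a plugin opts into this pattern. The class and the 'somesdk' module name are hypothetical; extra_dependency_names and load_optional_module() are as introduced above:

from garak.generators.base import Generator

class ToyGenerator(Generator):
    """Illustrative only: 'somesdk' stands in for any optional PyPI module."""

    # load_plugin() checks each listed name with importlib.util.find_spec()
    # before importing this module; Generator._load_deps() then binds the
    # imported module to self.somesdk at construction time
    extra_dependency_names = ["somesdk"]

    def _call_model(self, prompt: str, generations_this_call: int = 1):
        # the optional module is only touched when the plugin actually runs
        return [self.somesdk.complete(prompt)]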
2 changes: 2 additions & 0 deletions garak/buffs/base.py
@@ -27,6 +27,8 @@ class Buff(Configurable):
doc_uri = ""
lang = None # set of languages this buff should be constrained to
active = True
# list of strings naming modules required but not explicitly in garak by default
extra_dependency_names = []

DEFAULT_PARAMS = {}

2 changes: 2 additions & 0 deletions garak/detectors/base.py
@@ -27,6 +27,8 @@ class Detector(Configurable):
accuracy = None
active = True
tags = [] # list of taxonomy categories per the MISP format
# list of strings naming modules required but not explicitly in garak by default
extra_dependency_names = []

# support mainstream any-to-any large models
# legal element for str list `modality['in']`: 'text', 'image', 'audio', 'video', '3d'
4 changes: 4 additions & 0 deletions garak/exception.py
@@ -36,3 +36,7 @@ class ConfigFailure(GarakException):

class PayloadFailure(GarakException):
"""Problem instantiating/using payloads"""


class GeneratorBackoffExceptionPlaceholder(GarakException):
"""Placeholder used for lazy-loaded exceptions"""
1 change: 1 addition & 0 deletions garak/generators/azure.py
@@ -82,6 +82,7 @@ def _validate_env_var(self):
return super()._validate_env_var()

def _load_client(self):
self._load_deps()
if self.model_name in openai_model_mapping:
self.model_name = openai_model_mapping[self.model_name]

27 changes: 26 additions & 1 deletion garak/generators/base.py
@@ -44,6 +44,8 @@ class Generator(Configurable):
supports_multiple_generations = (
False # can more than one generation be extracted per request?
)
# list of strings naming modules required but not explicitly in garak by default
extra_dependency_names = []

def __init__(self, name="", config_root=_config):
self._load_config(config_root)
@@ -63,6 +65,29 @@ def __init__(self, name="", config_root=_config):
f"🦜 loading {Style.BRIGHT}{Fore.LIGHTMAGENTA_EX}generator{Style.RESET_ALL}: {self.generator_family_name}: {self.name}"
)
logging.info("generator init: %s", self)
self._load_deps()

def _load_deps(self):
# load external dependencies. should be invoked at construction and
# in _client_load (if used)
for extra_dependency in self.extra_dependency_names:
extra_dep_name = extra_dependency.replace(".", "_").replace("-", "_")
if (
not hasattr(self, extra_dep_name)
or getattr(self, extra_dep_name) is None
):
setattr(
self,
extra_dep_name,
garak._plugins.load_optional_module(extra_dependency),
)

def _clear_deps(self):
# unload external dependencies from class. should be invoked before
# serialisation, esp. in _clear_client (if used)
for extra_dependency in self.extra_dependency_names:
extra_dep_name = extra_dependency.replace(".", "_").replace("-", "_")
setattr(self, extra_dep_name, None)
Comment on lines +70 to +90 (Collaborator):

Should this be in Configurable instead, since it can/should be used across all base classes?

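One possible shape for that suggestion, sketched under the assumption that Configurable can import garak._plugins without circularity; not a drop-in patch:

import garak._plugins

class Configurable:
    # list of strings naming modules required but not explicitly in garak by default
    extra_dependency_names = []

    def _load_deps(self):
        # one implementation shared by Generator, Probe, Detector, and Buff
        for extra_dependency in self.extra_dependency_names:
            extra_dep_name = extra_dependency.replace(".", "_").replace("-", "_")
            if getattr(self, extra_dep_name, None) is None:
                setattr(
                    self,
                    extra_dep_name,
                    garak._plugins.load_optional_module(extra_dependency),
                )

    def _clear_deps(self):
        # mirror _load_deps' name mangling so hyphenated deps clear correctly
        for extra_dependency in self.extra_dependency_names:
            setattr(self, extra_dependency.replace(".", "_").replace("-", "_"), None)
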
def _call_model(
self, prompt: str, generations_this_call: int = 1
@@ -101,7 +126,7 @@ def _prune_skip_sequences(self, outputs: List[str | None]) -> List[str | None]:
)
rx_missing_final = re.escape(self.skip_seq_start) + ".*?$"
rx_missing_start = ".*?" + re.escape(self.skip_seq_end)

if self.skip_seq_start == "":
complete_seqs_removed = [
(
14 changes: 11 additions & 3 deletions garak/generators/cohere.py
@@ -10,17 +10,20 @@
from typing import List, Union

import backoff
import cohere
import tqdm

from garak import _config
import garak._plugins
from garak.exception import GeneratorBackoffExceptionPlaceholder
from garak.generators.base import Generator


COHERE_GENERATION_LIMIT = (
5 # c.f. https://docs.cohere.com/reference/generate 18 may 2023
)

cohere_exception = GeneratorBackoffExceptionPlaceholder


class CohereGenerator(Generator):
"""Interface to Cohere's python library for their text2text model.
@@ -38,22 +41,27 @@ class CohereGenerator(Generator):
"presence_penalty": 0.0,
"stop": [],
}
extra_dependency_names = ["cohere"]

supports_multiple_generations = True
generator_family_name = "Cohere"

def __init__(self, name="command", config_root=_config):

self.name = name
self.fullname = f"Cohere {self.name}"

super().__init__(self.name, config_root=config_root)

global cohere_exception
cohere_exception = self.cohere.error.CohereAPIError

logging.debug(
"Cohere generation request limit capped at %s", COHERE_GENERATION_LIMIT
)
self.generator = cohere.Client(self.api_key)
self.generator = self.cohere.Client(self.api_key)

@backoff.on_exception(backoff.fibo, cohere.error.CohereAPIError, max_value=70)
@backoff.on_exception(backoff.fibo, cohere_exception, max_value=70)
def _call_cohere_api(self, prompt, request_size=COHERE_GENERATION_LIMIT):
"""as of jun 2 2023, empty prompts raise:
cohere.error.CohereAPIError: invalid request: prompt must be at least 1 token long
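
One caveat worth flagging here: @backoff.on_exception evaluates its exception argument once, when the decorator runs at import time, so it keeps a reference to the placeholder class even after __init__ rebinds the module-level cohere_exception name. The later commit "unify backoff exception pattern mediated via garak GeneratorBackoffEx…" points at a sturdier shape: back off on a garak-owned wrapper and re-raise SDK errors as that wrapper. A sketch of that direction, with the wrapper name and the SDK call shape assumed:

import backoff
from garak.exception import GarakException

class GarakBackoffTrigger(GarakException):  # assumed name, illustrative only
    """Raised to request a retry, independent of any lazily-imported SDK"""

# as a method on CohereGenerator:
@backoff.on_exception(backoff.fibo, GarakBackoffTrigger, max_value=70)
def _call_cohere_api(self, prompt, request_size=COHERE_GENERATION_LIMIT):
    try:
        # kwargs abbreviated; the real call also passes temperature etc.
        return self.generator.generate(
            model=self.name, prompt=prompt, num_generations=request_size
        )
    except self.cohere.error.CohereAPIError as e:
        # translate the lazily-imported SDK error into the stable wrapper
        raise GarakBackoffTrigger(str(e)) from e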
1 change: 1 addition & 0 deletions garak/generators/groq.py
@@ -39,6 +39,7 @@ class GroqChat(OpenAICompatible):
generator_family_name = "Groq"

def _load_client(self):
self._load_deps()
self.client = openai.OpenAI(base_url=self.uri, api_key=self.api_key)
if self.name in ("", None):
raise ValueError(
18 changes: 6 additions & 12 deletions garak/generators/guardrails.py
@@ -16,27 +16,21 @@ class NeMoGuardrails(Generator):

supports_multiple_generations = False
generator_family_name = "Guardrails"
extra_dependency_names = ["nemoguardrails"]

def __init__(self, name="", config_root=_config):
# another class that may need to skip testing due to non required dependency
try:
from nemoguardrails import RailsConfig, LLMRails
from nemoguardrails.logging.verbose import set_verbose
except ImportError as e:
raise NameError(
"You must first install NeMo Guardrails using `pip install nemoguardrails`."
) from e

self.name = name
self._load_config(config_root)
self.fullname = f"Guardrails {self.name}"

super().__init__(self.name, config_root=config_root)

set_verbose = self.nemoguardrails.logging.verbose.set_verbose
# Currently, we use the model_name as the path to the config
with redirect_stderr(io.StringIO()) as f: # quieten the tqdm
config = RailsConfig.from_path(self.name)
self.rails = LLMRails(config=config)

super().__init__(self.name, config_root=config_root)
config = self.nemoguardrails.RailsConfig.from_path(self.name)
self.rails = self.nemoguardrails.LLMRails(config=config)

def _call_model(
self, prompt: str, generations_this_call: int = 1
22 changes: 12 additions & 10 deletions garak/generators/huggingface.py
@@ -21,7 +21,6 @@

import backoff
import torch
from PIL import Image

from garak import _config
from garak.exception import ModelNameMissingError, GarakException
@@ -70,6 +69,7 @@ def __init__(self, name="", config_root=_config):
self._load_client()

def _load_client(self):
self._load_deps()
if hasattr(self, "generator") and self.generator is not None:
return
Comment on lines 71 to 74 (Collaborator):

In the interest of DRYness, I notice this exact code repeated across a number of the HFCompatible classes; perhaps we should refactor into HFCompatible?


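A sketch of that refactor; the helper name is hypothetical, and it assumes each class guards on a single attribute:

class HFCompatible:
    def _client_ready(self, attr: str = "generator") -> bool:
        """Shared preamble: load optional deps, report whether client exists."""
        self._load_deps()
        return getattr(self, attr, None) is not None

Each subclass's _load_client() could then open with:

    if self._client_ready():  # Model would pass attr="model"
        return
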
@@ -104,6 +104,7 @@ def _load_client(self):
self._set_hf_context_len(self.generator.model.config)

def _clear_client(self):
self._clear_deps()
self.generator = None

def _format_chat_prompt(self, prompt: str) -> List[dict]:
@@ -158,19 +159,15 @@ class OptimumPipeline(Pipeline, HFCompatible):
generator_family_name = "NVIDIA Optimum Hugging Face 🤗 pipeline"
supports_multiple_generations = True
doc_uri = "https://huggingface.co/blog/optimum-nvidia"
extra_dependency_names = ["optimum"]

def _load_client(self):
self._load_deps()
if hasattr(self, "generator") and self.generator is not None:
return

try:
from optimum.nvidia.pipelines import pipeline
from transformers import set_seed
except Exception as e:
logging.exception(e)
raise GarakException(
f"Missing required dependencies for {self.__class__.__name__}"
)
pipeline = self.optimum.nvidia.pipelines.pipeline
from transformers import set_seed

if self.seed is not None:
set_seed(self.seed)
@@ -205,6 +202,7 @@ class ConversationalPipeline(Pipeline, HFCompatible):
supports_multiple_generations = True

def _load_client(self):
self._load_deps()
if hasattr(self, "generator") and self.generator is not None:
return

@@ -454,6 +452,7 @@ class Model(Pipeline, HFCompatible):
supports_multiple_generations = True

def _load_client(self):
self._load_deps()
if hasattr(self, "model") and self.model is not None:
return

@@ -501,6 +500,7 @@ def _load_client(self):
self.generation_config.pad_token_id = self.model.config.eos_token_id

def _clear_client(self):
self._clear_deps()
self.model = None
self.config = None
self.tokenizer = None
@@ -575,6 +575,8 @@ class LLaVA(Generator, HFCompatible):
NB. This should be used with strict modality matching - generate() doesn't
support text-only prompts."""

extra_dependency_names = ["PIL"]

DEFAULT_PARAMS = Generator.DEFAULT_PARAMS | {
"max_tokens": 4000,
# "exist_tokens + max_new_tokens < 4K is the golden rule."
@@ -626,7 +628,7 @@ def generate(

text_prompt = prompt["text"]
try:
image_prompt = Image.open(prompt["image"])
image_prompt = self.PIL.Image.open(prompt["image"])
except FileNotFoundError:
raise FileNotFoundError(f"Cannot open image {prompt['image']}.")
except Exception as e:
7 changes: 2 additions & 5 deletions garak/generators/langchain.py
@@ -8,9 +8,6 @@
import logging
from typing import List, Union


import langchain.llms

from garak import _config
from garak.generators.base import Generator

@@ -43,7 +40,7 @@ class LangChainLLMGenerator(Generator):
"presence_penalty": 0.0,
"stop": [],
}

extra_dependency_names = ["langchain.llms"]
generator_family_name = "LangChain"

def __init__(self, name="", config_root=_config):
@@ -55,7 +52,7 @@ def __init__(self, name="", config_root=_config):

Comment (Collaborator):

Are we missing a call to self._load_deps()?

try:
# this might need some special handling to allow tests
llm = getattr(langchain.llms, self.name)()
llm = getattr(self.langchain_llms, self.name)()
except Exception as e:
logging.error("Failed to import Langchain module: %s", repr(e))
raise e