feature: deferred loading and requirement pruning #1199
base: main
Changes from 20 commits
garak/generators/huggingface.py
@@ -21,7 +21,6 @@
 import backoff
 import torch
-from PIL import Image

 from garak import _config
 from garak.exception import ModelNameMissingError, GarakException
@@ -70,6 +69,7 @@ def __init__(self, name="", config_root=_config):
         self._load_client()

     def _load_client(self):
+        self._load_deps()
         if hasattr(self, "generator") and self.generator is not None:
             return
Review comment on lines 71 to 74: In the interest of DRYness, I notice this exact code repeated across a number of the …
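One hypothetical way to address the DRY concern (a sketch only; `_build_client` and `_client_attr` are invented names, not garak API): hoist the shared preamble into the base class as a template method, so each subclass implements just the construction step.

```python
class Generator:
    _client_attr = "generator"  # hypothetical: subclasses with a different slot override this

    def _load_deps(self):
        pass  # stub; the real hook imports the declared extra dependencies

    def _load_client(self):
        self._load_deps()
        # skip reconstruction if the client has already been built
        if getattr(self, self._client_attr, None) is not None:
            return
        self._build_client()

    def _build_client(self):
        raise NotImplementedError  # subclasses construct their pipeline/model here
```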
@@ -104,6 +104,7 @@ def _load_client(self):
         self._set_hf_context_len(self.generator.model.config)

     def _clear_client(self):
+        self._clear_deps()
         self.generator = None

     def _format_chat_prompt(self, prompt: str) -> List[dict]:
@@ -158,19 +159,15 @@ class OptimumPipeline(Pipeline, HFCompatible):
     generator_family_name = "NVIDIA Optimum Hugging Face 🤗 pipeline"
     supports_multiple_generations = True
     doc_uri = "https://huggingface.co/blog/optimum-nvidia"
+    extra_dependency_names = ["optimum-nvidia"]
Review comment: Minor note that has little to do with this PR: it does drive me a bit nuts that the dependency name and the import statement so often do not match.
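For illustration, the mismatch usually stems from pip distribution names differing from the module names you actually import; a lookup table is one way to keep the two spellings side by side (the entries below are examples, not garak's data):

```python
# Illustrative only: pip distribution name on the left, import name on the right.
DIST_TO_IMPORT = {
    "pillow": "PIL",                     # install `pillow`, import `PIL`
    "optimum-nvidia": "optimum.nvidia",  # dash in the dist name, dot in the import path
}
```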

     def _load_client(self):
+        self._load_deps()
         if hasattr(self, "generator") and self.generator is not None:
             return

-        try:
-            from optimum.nvidia.pipelines import pipeline
-            from transformers import set_seed
-        except Exception as e:
-            logging.exception(e)
-            raise GarakException(
-                f"Missing required dependencies for {self.__class__.__name__}"
-            )
+        pipeline = self.optimum.nvidia.pipelines.pipeline
+        from transformers import set_seed

         if self.seed is not None:
             set_seed(self.seed)
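Call sites like `self.optimum.nvidia.pipelines.pipeline` above suggest that `_load_deps()` imports each entry of `extra_dependency_names` on first use and attaches the module as an instance attribute. A minimal sketch of that idea; the attribute-naming rule here is an assumption inferred from the call sites in this diff, not the PR's actual implementation:

```python
import importlib

def _load_deps(self):
    for dep in getattr(self, "extra_dependency_names", []):
        # "langchain.llms" -> attribute "langchain_llms"; dots and dashes are not
        # valid in attribute names, so some mangling rule like this must exist.
        attr = dep.replace(".", "_").replace("-", "_")
        if getattr(self, attr, None) is None:
            # NB: a pip-style name such as "optimum-nvidia" is not importable as-is;
            # the real code presumably maps it to a module path like "optimum.nvidia"
            # (see the dist-to-import table above) before calling import_module.
            setattr(self, attr, importlib.import_module(dep))
```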
@@ -205,6 +202,7 @@ class ConversationalPipeline(Pipeline, HFCompatible):
     supports_multiple_generations = True

     def _load_client(self):
+        self._load_deps()
         if hasattr(self, "generator") and self.generator is not None:
             return
@@ -454,6 +452,7 @@ class Model(Pipeline, HFCompatible):
     supports_multiple_generations = True

     def _load_client(self):
+        self._load_deps()
         if hasattr(self, "model") and self.model is not None:
             return
@@ -501,6 +500,7 @@ def _load_client(self):
         self.generation_config.pad_token_id = self.model.config.eos_token_id

     def _clear_client(self):
+        self._clear_deps()
         self.model = None
         self.config = None
         self.tokenizer = None
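Presumably `_clear_deps()` mirrors `_load_deps()`, dropping the cached module references so a cleared generator keeps no heavyweight handles alive. A sketch under that assumption:

```python
def _clear_deps(self):
    # Assumed counterpart to _load_deps: release each cached module reference
    # so _clear_client() leaves nothing dangling for the garbage collector.
    for dep in getattr(self, "extra_dependency_names", []):
        attr = dep.replace(".", "_").replace("-", "_")
        if hasattr(self, attr):
            setattr(self, attr, None)
```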
@@ -575,6 +575,8 @@ class LLaVA(Generator, HFCompatible):
     NB. This should be use with strict modality matching - generate() doesn't
     support text-only prompts."""

+    extra_dependency_names = ["PIL"]
+
     DEFAULT_PARAMS = Generator.DEFAULT_PARAMS | {
         "max_tokens": 4000,
         # "exist_tokens + max_new_tokens < 4K is the golden rule."
@@ -626,7 +628,7 @@ def generate(

         text_prompt = prompt["text"]
         try:
-            image_prompt = Image.open(prompt["image"])
+            image_prompt = self.PIL.Image.open(prompt["image"])
         except FileNotFoundError:
             raise FileNotFoundError(f"Cannot open image {prompt['image']}.")
         except Exception as e:
garak/generators/langchain.py
@@ -8,9 +8,6 @@
 import logging
 from typing import List, Union

-
-import langchain.llms
-
 from garak import _config
 from garak.generators.base import Generator
@@ -43,7 +40,7 @@ class LangChainLLMGenerator(Generator):
         "presence_penalty": 0.0,
         "stop": [],
     }
-
+    extra_dependency_names = ["langchain.llms"]
     generator_family_name = "LangChain"

     def __init__(self, name="", config_root=_config):
@@ -53,14 +50,7 @@ def __init__(self, name="", config_root=_config):

         super().__init__(self.name, config_root=config_root)
Review comment: Are we missing a call to …
-        try:
-            # this might need some special handling to allow tests
-            llm = getattr(langchain.llms, self.name)()
-        except Exception as e:
-            logging.error("Failed to import Langchain module: %s", repr(e))
-            raise e
-
-        self.generator = llm
+        self.generator = getattr(self.langchain_llms, self.name)()

     def _call_model(
         self, prompt: str, generations_this_call: int = 1
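For a concrete reading of the new one-liner: with `name="OpenAI"` (an assumed example; any LLM class present in `langchain.llms` would do), and assuming `_load_deps()` has populated `self.langchain_llms`, the lookup reduces to:

```python
import importlib

# Roughly what the new line does, step by step (illustrative, not PR code):
langchain_llms = importlib.import_module("langchain.llms")  # cached by _load_deps()
generator = getattr(langchain_llms, "OpenAI")()             # instantiate the LLM wrapper
```

One behavioural difference worth noting: the removed block logged the failure before re-raising, whereas the new form surfaces a bad `name` as a bare `AttributeError` unless `_load_deps()` handles it.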
Review comment: Should this be in Configurable instead, since it can/should be used across all base classes?
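A sketch of that suggestion, assuming a mixin-style placement (method bodies elided; not garak's implementation):

```python
class Configurable:
    # Declared per plugin class; the empty default leaves existing plugins unaffected.
    extra_dependency_names: list = []

    def _load_deps(self):
        ...  # import and attach each extra dependency (see the sketch further up)

    def _clear_deps(self):
        ...  # drop the attached module references


# Generators, probes, and detectors would then all inherit the hooks.
class Generator(Configurable):
    pass
```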