-
Notifications
You must be signed in to change notification settings - Fork 474
feature: deferred loading and requirement pruning #1199
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 38 commits
3000d4c
757e0f3
9310d0a
dac569e
35e93fc
bf7f36b
6a39b0c
56c6182
3657e04
865d604
d61957d
8a7051e
60775f6
31e98d4
75babb7
83f551a
dd51196
b33a46c
de5b3f1
19c31fe
54fabc5
ffac714
97c8160
1d4e69c
6164bc5
85fb7c3
0402116
e287fe9
6339648
76b1774
ca133e4
8e8a5b9
aa7500a
4f2e5ef
69cfef2
a1da5ed
3a8605d
d2d17ad
13974b8
dc83929
d527650
8c46730
06180b6
7c22dea
ce23d70
8ab94bd
bb67a3e
b45ba35
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -43,7 +43,7 @@ jobs: | |
- name: Install dependencies | ||
run: | | ||
python -m pip install --upgrade pip | ||
pip install -r requirements.txt | ||
pip install . | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This should install all dependencies as the cache file needs to include all plugins. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. updated |
||
- name: Build a local cache | ||
run: | | ||
export TZ=UTC | ||
|
leondz marked this conversation as resolved.
Show resolved
Hide resolved
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -49,7 +49,7 @@ jobs: | |
- name: Install dependencies | ||
run: | | ||
python -m pip install --upgrade pip | ||
pip install --no-cache-dir -r requirements.txt | ||
pip install --no-cache-dir .[tests] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Either |
||
python -m pip cache purge | ||
|
||
- name: Restore test cache artifacts | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -400,6 +400,23 @@ def load_plugin(path, break_on_fail=True, config_root=_config) -> object: | |
) from ve | ||
else: | ||
return False | ||
|
||
full_plugin_name = ".".join((category, module_name, plugin_class_name)) | ||
|
||
# check cache for optional imports | ||
if category in PLUGIN_TYPES: | ||
extra_dependency_names = PluginCache.instance()[category][full_plugin_name][ | ||
"extra_dependency_names" | ||
] | ||
if len(extra_dependency_names) > 0: | ||
for dependency_module_name in extra_dependency_names: | ||
for dependency_path in [ # support both plain names and also multi-point names e.g. langchain.llms | ||
".".join(dependency_module_name.split(".")[: n + 1]) | ||
for n in range(dependency_module_name.count(".") + 1) | ||
]: | ||
if importlib.util.find_spec(dependency_path) is None: | ||
_import_failed(dependency_path, full_plugin_name) | ||
Comment on lines
+408
to
+420
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this really the best way to do this? Perhaps we just enforce lazy loading throughout instead? I'm not sure. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh, I guess that is what we're doing. This is the hazard of doing code reviews linearly, I suppose. |
||
|
||
module_path = f"garak.{category}.{module_name}" | ||
try: | ||
mod = importlib.import_module(module_path) | ||
|
@@ -426,6 +443,7 @@ def load_plugin(path, break_on_fail=True, config_root=_config) -> object: | |
if plugin_instance is None: | ||
plugin_instance = klass(config_root=config_root) | ||
PluginProvider.storeInstance(plugin_instance, config_root) | ||
|
||
except Exception as e: | ||
logging.warning( | ||
"Exception instantiating %s.%s: %s", | ||
|
@@ -440,3 +458,20 @@ def load_plugin(path, break_on_fail=True, config_root=_config) -> object: | |
return False | ||
|
||
return plugin_instance | ||
|
||
|
||
def load_optional_module(module_name: str): | ||
try: | ||
m = importlib.import_module(module_name) | ||
except ModuleNotFoundError: | ||
requesting_module = Path(inspect.stack()[1].filename).name.replace(".py", "") | ||
_import_failed(module_name, requesting_module) | ||
return m | ||
|
||
|
||
def _import_failed(import_module: str, calling_module: str): | ||
msg = f"⛔ Plugin '{calling_module}' requires Python module '{import_module}' but this isn't installed/available." | ||
hint = f"💡 Try 'pip install {import_module}' to get it." | ||
logging.critical(msg) | ||
print(msg + "\n" + hint) | ||
raise ModuleNotFoundError(msg) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,7 +14,7 @@ class ModelNameMissingError(GarakException): | |
"""A generator requires model_name to be set, but it wasn't""" | ||
|
||
|
||
class GarakBackoffTrigger(GarakException): | ||
class GeneratorBackoffTrigger(GarakException): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why rename this? The original name seems clear enough to me, I there some envisioned case where two layers of backoff would need to differentiate the source? |
||
"""Thrown when backoff should be triggered""" | ||
|
||
|
||
|
@@ -36,3 +36,4 @@ class ConfigFailure(GarakException): | |
|
||
class PayloadFailure(GarakException): | ||
"""Problem instantiating/using payloads""" | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -44,6 +44,8 @@ class Generator(Configurable): | |
supports_multiple_generations = ( | ||
False # can more than one generation be extracted per request? | ||
) | ||
# list of strings naming modules required but not explicitly in garak by default | ||
extra_dependency_names = [] | ||
|
||
def __init__(self, name="", config_root=_config): | ||
self._load_config(config_root) | ||
|
@@ -63,6 +65,29 @@ def __init__(self, name="", config_root=_config): | |
f"🦜 loading {Style.BRIGHT}{Fore.LIGHTMAGENTA_EX}generator{Style.RESET_ALL}: {self.generator_family_name}: {self.name}" | ||
) | ||
logging.info("generator init: %s", self) | ||
self._load_deps() | ||
|
||
def _load_deps(self): | ||
# load external dependencies. should be invoked at construction and | ||
# in _client_load (if used) | ||
for extra_dependency in self.extra_dependency_names: | ||
extra_dep_name = extra_dependency.replace(".", "_").replace("-", "_") | ||
if ( | ||
not hasattr(self, extra_dep_name) | ||
or getattr(self, extra_dep_name) is None | ||
): | ||
setattr( | ||
self, | ||
extra_dep_name, | ||
garak._plugins.load_optional_module(extra_dependency), | ||
) | ||
|
||
def _clear_deps(self): | ||
# unload external dependencies from class. should be invoked before | ||
# serialisation, esp. in _clear_client (if used) | ||
for extra_dependency in self.extra_dependency_names: | ||
extra_dep_name = extra_dependency.replace(".", "_") | ||
setattr(self, extra_dep_name, None) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should this be in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. definitely makes sense to factor it up, thanks There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. the generator implementation is special, because generators have client load/unload for probe parallelisation. will slate upfactoring for a second iteration. |
||
|
||
def _call_model( | ||
self, prompt: str, generations_this_call: int = 1 | ||
|
@@ -101,7 +126,7 @@ def _prune_skip_sequences(self, outputs: List[str | None]) -> List[str | None]: | |
) | ||
rx_missing_final = re.escape(self.skip_seq_start) + ".*?$" | ||
rx_missing_start = ".*?" + re.escape(self.skip_seq_end) | ||
|
||
if self.skip_seq_start == "": | ||
complete_seqs_removed = [ | ||
( | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think linter will want to have all possible dependencies.