Skip to content

Commit c213a8b

Browse files
authored
Merge pull request #995 from guardrails-ai/feat/cli-create-install-local-models
Create CLI command to respect install local models
2 parents bf26ac6 + adea449 commit c213a8b

File tree

5 files changed

+86
-316
lines changed

5 files changed

+86
-316
lines changed

guardrails/cli/create.py

Lines changed: 44 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,14 @@
11
import os
22
import sys
33
import time
4-
from typing import Dict, List, Optional, Union
4+
from typing import Dict, List, Optional, Union, cast
55

66
import typer
77
import json
88
from rich.console import Console
99
from rich.syntax import Syntax
1010

1111
from guardrails.cli.guardrails import guardrails as gr_cli
12-
from guardrails.cli.hub.install import ( # JC: I don't like this import. Move fns?
13-
install_hub_module,
14-
add_to_hub_inits,
15-
run_post_install,
16-
)
17-
from guardrails.cli.hub.utils import get_site_packages_location
18-
from guardrails.cli.server.hub_client import get_validator_manifest
1912
from guardrails.cli.hub.template import get_template
2013

2114
console = Console()
@@ -30,6 +23,11 @@ def create_command(
3023
name: Optional[str] = typer.Option(
3124
default=None, help="The name of the guard to define in the file."
3225
),
26+
local_models: Optional[bool] = typer.Option(
27+
None,
28+
"--install-local-models/--no-install-local-models",
29+
help="Install local models",
30+
),
3331
filepath: str = typer.Option(
3432
default="config.py",
3533
help="The path to which the configuration file should be saved.",
@@ -47,6 +45,8 @@ def create_command(
4745
help="Print out the validators to be installed without making any changes.",
4846
),
4947
):
48+
# fix pyright typing issue
49+
validators = cast(str, validators)
5050
filepath = check_filename(filepath)
5151

5252
if not validators and template is not None:
@@ -56,7 +56,11 @@ def create_command(
5656
for validator in guard["validators"]:
5757
validators_map[f"hub://{validator['id']}"] = True
5858
validators = ",".join(validators_map.keys())
59-
installed_validators = split_and_install_validators(validators, dry_run) # type: ignore
59+
installed_validators = split_and_install_validators(
60+
validators,
61+
local_models,
62+
dry_run,
63+
)
6064
new_config_file = generate_template_config(
6165
template_dict, installed_validators, template_file_name
6266
)
@@ -67,7 +71,11 @@ def create_command(
6771
)
6872
sys.exit(1)
6973
else:
70-
installed_validators = split_and_install_validators(validators, dry_run) # type: ignore
74+
installed_validators = split_and_install_validators(
75+
validators,
76+
local_models,
77+
dry_run,
78+
)
7179
if name is None and validators:
7280
name = "Guard"
7381
if len(installed_validators) > 0:
@@ -137,53 +145,52 @@ def check_filename(filename: Union[str, os.PathLike]) -> str:
137145
return filename # type: ignore
138146

139147

140-
def split_and_install_validators(validators: str, dry_run: bool = False):
148+
def split_and_install_validators(
149+
validators: str, local_models: Union[bool, None], dry_run: bool = False
150+
):
141151
"""Given a comma-separated list of validators, check the hub to make sure
142152
all of them exist, install them, and return a list of 'imports'.
143153
144154
If validators is empty, returns an empty list.
145155
"""
156+
from guardrails.hub.install import install
157+
158+
def install_local_models_confirm():
159+
return typer.confirm(
160+
"This validator has a Guardrails AI inference endpoint available. "
161+
"Would you still like to install the"
162+
" local models for local inference?",
163+
)
164+
146165
if not validators:
147166
return []
148167

149-
stripped_validators = list()
150-
manifests = list()
151-
site_packages = get_site_packages_location()
168+
manifest_exports = list()
152169

153170
# Split by comma, strip start and end spaces, then make sure there's a hub prefix.
154171
# If all that passes, download the manifest file so we know where to install.
155172
# hub://blah -> blah, then download the manifest.
156-
console.print("Checking validators...")
157-
with console.status("Checking validator manifests") as status:
158-
for v in validators.split(","):
159-
v = v.strip()
160-
status.update(f"Prefetching {v}")
161-
if not v.startswith("hub://"):
162-
console.print(
163-
f"WARNING: Validator {v} does not appear to be a valid URI."
164-
)
165-
sys.exit(-1)
166-
stripped_validator = v.lstrip("hub://")
167-
stripped_validators.append(stripped_validator)
168-
manifests.append(get_validator_manifest(stripped_validator))
169-
console.print("Success!")
170-
171-
# We should make sure they exist.
172173
console.print("Installing...")
173174
with console.status("Installing validators") as status:
174-
for manifest, validator in zip(manifests, stripped_validators):
175-
status.update(f"Installing {validator}")
175+
for v in validators.split(","):
176+
validator_hub_uri = v.strip()
177+
status.update(f"Installing {v}")
176178
if not dry_run:
177-
install_hub_module(manifest, site_packages, quiet=True)
178-
run_post_install(manifest, site_packages)
179-
add_to_hub_inits(manifest, site_packages)
179+
module = install(
180+
package_uri=validator_hub_uri,
181+
install_local_models=local_models,
182+
quiet=True,
183+
install_local_models_confirm=install_local_models_confirm,
184+
)
185+
exports = module.__validator_exports__
186+
manifest_exports.append(exports[0])
180187
else:
181-
console.print(f"Fake installing {validator}")
188+
console.print(f"Fake installing {validator_hub_uri}")
182189
time.sleep(1)
183190
console.print("Success!")
184191

185192
# Pull the hub information from each of the installed validators and return it.
186-
return [manifest.exports[0] for manifest in manifests]
193+
return manifest_exports
187194

188195

189196
def generate_config_file(validators: List[str], name: Optional[str] = None) -> str:

guardrails/cli/hub/install.py

Lines changed: 1 addition & 177 deletions
Original file line numberDiff line numberDiff line change
@@ -1,186 +1,10 @@
1-
import os
2-
import subprocess
31
import sys
4-
from typing import List, Literal, Optional
2+
from typing import Optional
53

64
import typer
75

8-
from guardrails.classes.generic import Stack
96
from guardrails.cli.hub.hub import hub_command
10-
11-
from guardrails.cli.hub.utils import (
12-
get_hub_directory,
13-
get_org_and_package_dirs,
14-
pip_process,
15-
)
167
from guardrails.cli.logger import logger
17-
from guardrails.cli.server.module_manifest import ModuleManifest
18-
19-
20-
def removesuffix(string: str, suffix: str) -> str:
21-
if sys.version_info.minor >= 9:
22-
return string.removesuffix(suffix) # type: ignore
23-
else:
24-
if string.endswith(suffix):
25-
return string[: -len(suffix)]
26-
return string
27-
28-
29-
string_format: Literal["string"] = "string"
30-
json_format: Literal["json"] = "json"
31-
32-
33-
# NOTE: I don't like this but don't see another way without
34-
# shimming the init file with all hub validators
35-
def add_to_hub_inits(manifest: ModuleManifest, site_packages: str):
36-
org_package = get_org_and_package_dirs(manifest)
37-
exports: List[str] = manifest.exports or []
38-
sorted_exports = sorted(exports, reverse=True)
39-
module_name = manifest.module_name
40-
relative_path = ".".join([*org_package, module_name])
41-
import_line = (
42-
f"from guardrails.hub.{relative_path} import {', '.join(sorted_exports)}"
43-
)
44-
45-
hub_init_location = os.path.join(site_packages, "guardrails", "hub", "__init__.py")
46-
with open(hub_init_location, "a+") as hub_init:
47-
hub_init.seek(0, 0)
48-
content = hub_init.read()
49-
if import_line in content:
50-
hub_init.close()
51-
else:
52-
hub_init.seek(0, 2)
53-
if len(content) > 0:
54-
hub_init.write("\n")
55-
hub_init.write(import_line)
56-
hub_init.close()
57-
58-
namespace = org_package[0]
59-
namespace_init_location = os.path.join(
60-
site_packages, "guardrails", "hub", namespace, "__init__.py"
61-
)
62-
if os.path.isfile(namespace_init_location):
63-
with open(namespace_init_location, "a+") as namespace_init:
64-
namespace_init.seek(0, 0)
65-
content = namespace_init.read()
66-
if import_line in content:
67-
namespace_init.close()
68-
else:
69-
namespace_init.seek(0, 2)
70-
if len(content) > 0:
71-
namespace_init.write("\n")
72-
namespace_init.write(import_line)
73-
namespace_init.close()
74-
else:
75-
with open(namespace_init_location, "w") as namespace_init:
76-
namespace_init.write(import_line)
77-
namespace_init.close()
78-
79-
80-
def run_post_install(manifest: ModuleManifest, site_packages: str):
81-
org_package = get_org_and_package_dirs(manifest)
82-
post_install_script = manifest.post_install
83-
if not post_install_script:
84-
return
85-
86-
module_name = manifest.module_name
87-
relative_path = os.path.join(
88-
site_packages,
89-
"guardrails",
90-
"hub",
91-
*org_package,
92-
module_name,
93-
post_install_script,
94-
)
95-
96-
if os.path.isfile(relative_path):
97-
try:
98-
logger.debug("running post install script...")
99-
command = [sys.executable, relative_path]
100-
subprocess.check_output(command)
101-
except subprocess.CalledProcessError as exc:
102-
logger.error(
103-
(
104-
f"Failed to run post install script for {manifest.id}\n"
105-
f"Exit code: {exc.returncode}\n"
106-
f"stdout: {exc.output}"
107-
)
108-
)
109-
sys.exit(1)
110-
except Exception as e:
111-
logger.error(
112-
f"An unexpected exception occurred while running the post install script for {manifest.id}!", # noqa
113-
e,
114-
)
115-
sys.exit(1)
116-
117-
118-
def get_install_url(manifest: ModuleManifest) -> str:
119-
repo = manifest.repository
120-
repo_url = repo.url
121-
branch = repo.branch
122-
123-
git_url = repo_url
124-
if not repo_url.startswith("git+"):
125-
git_url = f"git+{repo_url}"
126-
127-
if branch is not None:
128-
git_url = f"{git_url}@{branch}"
129-
130-
return git_url
131-
132-
133-
def install_hub_module(
134-
module_manifest: ModuleManifest, site_packages: str, quiet: bool = False
135-
):
136-
install_url = get_install_url(module_manifest)
137-
install_directory = get_hub_directory(module_manifest, site_packages)
138-
139-
pip_flags = [f"--target={install_directory}", "--no-deps"]
140-
if quiet:
141-
pip_flags.append("-q")
142-
143-
# Install validator module in namespaced directory under guardrails.hub
144-
download_output = pip_process("install", install_url, pip_flags, quiet=quiet)
145-
if not quiet:
146-
logger.info(download_output)
147-
148-
# Install validator module's dependencies in normal site-packages directory
149-
inspect_output = pip_process(
150-
"inspect",
151-
flags=[f"--path={install_directory}"],
152-
format=json_format,
153-
quiet=quiet,
154-
no_color=True,
155-
)
156-
157-
# throw if inspect_output is a string. Mostly for pyright
158-
if isinstance(inspect_output, str):
159-
logger.error("Failed to inspect the installed package!")
160-
sys.exit(1)
161-
162-
dependencies = (
163-
Stack(*inspect_output.get("installed", []))
164-
.at(0, {})
165-
.get("metadata", {}) # type: ignore
166-
.get("requires_dist", []) # type: ignore
167-
)
168-
requirements = list(filter(lambda dep: "extra" not in dep, dependencies))
169-
for req in requirements:
170-
if "git+" in req:
171-
install_spec = req.replace(" ", "")
172-
dep_install_output = pip_process("install", install_spec, quiet=quiet)
173-
if not quiet:
174-
logger.info(dep_install_output)
175-
else:
176-
req_info = Stack(*req.split(" "))
177-
name = req_info.at(0, "").strip() # type: ignore
178-
versions = req_info.at(1, "").strip("()") # type: ignore
179-
if name:
180-
install_spec = name if not versions else f"{name}{versions}"
181-
dep_install_output = pip_process("install", install_spec, quiet=quiet)
182-
if not quiet:
183-
logger.info(dep_install_output)
1848

1859

18610
@hub_command.command()

guardrails/hub/install.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
from contextlib import contextmanager
22
from string import Template
3-
from typing import Callable
3+
from typing import Callable, cast
44

5-
from guardrails.hub.validator_package_service import ValidatorPackageService
5+
from guardrails.hub.validator_package_service import (
6+
ValidatorPackageService,
7+
ValidatorModuleType,
8+
)
69
from guardrails.classes.credentials import Credentials
710

811
from guardrails.cli.hub.console import console
@@ -33,7 +36,7 @@ def install(
3336
install_local_models=None,
3437
quiet: bool = True,
3538
install_local_models_confirm: Callable = default_local_models_confirm,
36-
):
39+
) -> ValidatorModuleType:
3740
"""
3841
Install a validator package from a hub URI.
3942
@@ -49,6 +52,9 @@ def install(
4952
5053
Examples:
5154
>>> RegexMatch = install("hub://guardrails/regex_match").RegexMatch
55+
56+
>>> install("hub://guardrails/regex_match);
57+
>>> import guardrails.hub.regex_match as regex_match
5258
"""
5359

5460
verbose_printer = console.print
@@ -122,7 +128,10 @@ def install(
122128
ValidatorPackageService.add_to_hub_inits(module_manifest, site_packages)
123129

124130
# 5. Get Validator Class for the installed module
125-
validators = ValidatorPackageService.get_validator_from_manifest(module_manifest)
131+
installed_module = ValidatorPackageService.get_validator_from_manifest(
132+
module_manifest
133+
)
134+
installed_module = cast(ValidatorModuleType, installed_module)
126135

127136
# Print success messages
128137
cli_logger.info("Installation complete")
@@ -152,4 +161,7 @@ def install(
152161
quiet_printer(success_message_cli) # type: ignore
153162
cli_logger.log(level=LEVELS.get("SPAM"), msg=success_message_logger) # type: ignore
154163

155-
return validators
164+
# Not a fan of this but allows the installation to be used in create command as is
165+
installed_module.__validator_exports__ = module_manifest.exports
166+
167+
return installed_module

0 commit comments

Comments
 (0)