1
1
import os
2
2
import sys
3
3
import time
4
- from typing import Dict , List , Optional , Union
4
+ from typing import Dict , List , Optional , Union , cast
5
5
6
6
import typer
7
7
import json
8
8
from rich .console import Console
9
9
from rich .syntax import Syntax
10
10
11
11
from guardrails .cli .guardrails import guardrails as gr_cli
12
- from guardrails .cli .hub .install import ( # JC: I don't like this import. Move fns?
13
- install_hub_module ,
14
- add_to_hub_inits ,
15
- run_post_install ,
16
- )
17
- from guardrails .cli .hub .utils import get_site_packages_location
18
- from guardrails .cli .server .hub_client import get_validator_manifest
19
12
from guardrails .cli .hub .template import get_template
20
13
21
14
console = Console ()
@@ -30,6 +23,11 @@ def create_command(
30
23
name : Optional [str ] = typer .Option (
31
24
default = None , help = "The name of the guard to define in the file."
32
25
),
26
+ local_models : Optional [bool ] = typer .Option (
27
+ None ,
28
+ "--install-local-models/--no-install-local-models" ,
29
+ help = "Install local models" ,
30
+ ),
33
31
filepath : str = typer .Option (
34
32
default = "config.py" ,
35
33
help = "The path to which the configuration file should be saved." ,
@@ -47,6 +45,8 @@ def create_command(
47
45
help = "Print out the validators to be installed without making any changes." ,
48
46
),
49
47
):
48
+ # fix pyright typing issue
49
+ validators = cast (str , validators )
50
50
filepath = check_filename (filepath )
51
51
52
52
if not validators and template is not None :
@@ -56,7 +56,11 @@ def create_command(
56
56
for validator in guard ["validators" ]:
57
57
validators_map [f"hub://{ validator ['id' ]} " ] = True
58
58
validators = "," .join (validators_map .keys ())
59
- installed_validators = split_and_install_validators (validators , dry_run ) # type: ignore
59
+ installed_validators = split_and_install_validators (
60
+ validators ,
61
+ local_models ,
62
+ dry_run ,
63
+ )
60
64
new_config_file = generate_template_config (
61
65
template_dict , installed_validators , template_file_name
62
66
)
@@ -67,7 +71,11 @@ def create_command(
67
71
)
68
72
sys .exit (1 )
69
73
else :
70
- installed_validators = split_and_install_validators (validators , dry_run ) # type: ignore
74
+ installed_validators = split_and_install_validators (
75
+ validators ,
76
+ local_models ,
77
+ dry_run ,
78
+ )
71
79
if name is None and validators :
72
80
name = "Guard"
73
81
if len (installed_validators ) > 0 :
@@ -137,53 +145,52 @@ def check_filename(filename: Union[str, os.PathLike]) -> str:
137
145
return filename # type: ignore
138
146
139
147
140
- def split_and_install_validators (validators : str , dry_run : bool = False ):
148
+ def split_and_install_validators (
149
+ validators : str , local_models : Union [bool , None ], dry_run : bool = False
150
+ ):
141
151
"""Given a comma-separated list of validators, check the hub to make sure
142
152
all of them exist, install them, and return a list of 'imports'.
143
153
144
154
If validators is empty, returns an empty list.
145
155
"""
156
+ from guardrails .hub .install import install
157
+
158
+ def install_local_models_confirm ():
159
+ return typer .confirm (
160
+ "This validator has a Guardrails AI inference endpoint available. "
161
+ "Would you still like to install the"
162
+ " local models for local inference?" ,
163
+ )
164
+
146
165
if not validators :
147
166
return []
148
167
149
- stripped_validators = list ()
150
- manifests = list ()
151
- site_packages = get_site_packages_location ()
168
+ manifest_exports = list ()
152
169
153
170
# Split by comma, strip start and end spaces, then make sure there's a hub prefix.
154
171
# If all that passes, download the manifest file so we know where to install.
155
172
# hub://blah -> blah, then download the manifest.
156
- console .print ("Checking validators..." )
157
- with console .status ("Checking validator manifests" ) as status :
158
- for v in validators .split ("," ):
159
- v = v .strip ()
160
- status .update (f"Prefetching { v } " )
161
- if not v .startswith ("hub://" ):
162
- console .print (
163
- f"WARNING: Validator { v } does not appear to be a valid URI."
164
- )
165
- sys .exit (- 1 )
166
- stripped_validator = v .lstrip ("hub://" )
167
- stripped_validators .append (stripped_validator )
168
- manifests .append (get_validator_manifest (stripped_validator ))
169
- console .print ("Success!" )
170
-
171
- # We should make sure they exist.
172
173
console .print ("Installing..." )
173
174
with console .status ("Installing validators" ) as status :
174
- for manifest , validator in zip (manifests , stripped_validators ):
175
- status .update (f"Installing { validator } " )
175
+ for v in validators .split ("," ):
176
+ validator_hub_uri = v .strip ()
177
+ status .update (f"Installing { v } " )
176
178
if not dry_run :
177
- install_hub_module (manifest , site_packages , quiet = True )
178
- run_post_install (manifest , site_packages )
179
- add_to_hub_inits (manifest , site_packages )
179
+ module = install (
180
+ package_uri = validator_hub_uri ,
181
+ install_local_models = local_models ,
182
+ quiet = True ,
183
+ install_local_models_confirm = install_local_models_confirm ,
184
+ )
185
+ exports = module .__validator_exports__
186
+ manifest_exports .append (exports [0 ])
180
187
else :
181
- console .print (f"Fake installing { validator } " )
188
+ console .print (f"Fake installing { validator_hub_uri } " )
182
189
time .sleep (1 )
183
190
console .print ("Success!" )
184
191
185
192
# Pull the hub information from each of the installed validators and return it.
186
- return [ manifest . exports [ 0 ] for manifest in manifests ]
193
+ return manifest_exports
187
194
188
195
189
196
def generate_config_file (validators : List [str ], name : Optional [str ] = None ) -> str :
0 commit comments