diff --git a/synthetic_data_kit/cli.py b/synthetic_data_kit/cli.py
index 98df553..d224544 100644
--- a/synthetic_data_kit/cli.py
+++ b/synthetic_data_kit/cli.py
@@ -46,6 +46,9 @@ def callback(
 def system_check(
     api_base: Optional[str] = typer.Option(
         None, "--api-base", help="VLLM API base URL to check"
+    ),
+    api_key: Optional[str] = typer.Option(
+        None, "--api-key", help="API key for authentication"
     )
 ):
     """
@@ -54,10 +57,14 @@ def system_check(
     # Get VLLM server details from args or config
     vllm_config = get_vllm_config(ctx.config)
    api_base = api_base or vllm_config.get("api_base")
+    api_key = api_key or vllm_config.get("api_key")
+    headers = {}
+    if api_key:
+        headers["Authorization"] = f"Bearer {api_key}"
 
     with console.status(f"Checking VLLM server at {api_base}..."):
         try:
-            response = requests.get(f"{api_base}/models", timeout=2)
+            response = requests.get(f"{api_base}/models", headers=headers, timeout=2)
             if response.status_code == 200:
                 console.print(f"✓ VLLM server is running at {api_base}", style="green")
                 console.print(f"Available models: {response.json()}")
@@ -121,6 +128,9 @@ def create(
     model: Optional[str] = typer.Option(
         None, "--model", "-m", help="Model to use"
     ),
+    api_key: Optional[str] = typer.Option(
+        None, "--api-key", help="API key for authentication"
+    ),
     num_pairs: Optional[int] = typer.Option(
         None, "--num-pairs", "-n", help="Target number of QA pairs to generate"
     ),
@@ -147,10 +157,14 @@ def create(
     vllm_config = get_vllm_config(ctx.config)
     api_base = api_base or vllm_config.get("api_base")
     model = model or vllm_config.get("model")
+    api_key = api_key or vllm_config.get("api_key")
+    headers = {}
+    if api_key:
+        headers["Authorization"] = f"Bearer {api_key}"
 
     # Check server first
     try:
-        response = requests.get(f"{api_base}/models", timeout=2)
+        response = requests.get(f"{api_base}/models", headers=headers, timeout=2)
         if response.status_code != 200:
             console.print(f"✗ Error: VLLM server not available at {api_base}", style="red")
             console.print("Please start the VLLM server with:", style="yellow")
@@ -201,6 +215,9 @@ def curate(
     model: Optional[str] = typer.Option(
         None, "--model", "-m", help="Model to use"
     ),
+    api_key: Optional[str] = typer.Option(
+        None, "--api-key", help="API key for authentication"
+    ),
     verbose: bool = typer.Option(
         False, "--verbose", "-v", help="Show detailed output"
     ),
@@ -214,10 +231,14 @@ def curate(
     vllm_config = get_vllm_config(ctx.config)
     api_base = api_base or vllm_config.get("api_base")
     model = model or vllm_config.get("model")
+    api_key = api_key or vllm_config.get("api_key")
+    headers = {}
+    if api_key:
+        headers["Authorization"] = f"Bearer {api_key}"
 
     # Check server first
     try:
-        response = requests.get(f"{api_base}/models", timeout=2)
+        response = requests.get(f"{api_base}/models", headers=headers, timeout=2)
         if response.status_code != 200:
             console.print(f"✗ Error: VLLM server not available at {api_base}", style="red")
             console.print("Please start the VLLM server with:", style="yellow")
diff --git a/synthetic_data_kit/models/llm_client.py b/synthetic_data_kit/models/llm_client.py
index e5d2c27..a7895fb 100644
--- a/synthetic_data_kit/models/llm_client.py
+++ b/synthetic_data_kit/models/llm_client.py
@@ -19,7 +19,8 @@ def __init__(self,
                  api_base: Optional[str] = None,
                  model_name: Optional[str] = None,
                  max_retries: Optional[int] = None,
-                 retry_delay: Optional[float] = None):
+                 retry_delay: Optional[float] = None,
+                 api_key: Optional[str] = None):  # added api_key parameter
         """Initialize an OpenAI-compatible client that connects to a VLLM server
 
         Args:
@@ -28,6 +29,7 @@ def __init__(self,
             model_name: Override model name from config
             max_retries: Override max retries from config
             retry_delay: Override retry delay from config
+            api_key: Override API key from config
         """
         # Load config
         self.config = load_config(config_path)
@@ -38,6 +40,7 @@ def __init__(self,
         self.model = model_name or vllm_config.get('model')
         self.max_retries = max_retries or vllm_config.get('max_retries')
         self.retry_delay = retry_delay or vllm_config.get('retry_delay')
+        self.api_key = api_key or vllm_config.get('api_key')  # save API key
 
         # Verify server is running
         available, info = self._check_server()
@@ -47,7 +50,10 @@ def __init__(self,
     def _check_server(self) -> tuple:
         """Check if the VLLM server is running and accessible"""
         try:
-            response = requests.get(f"{self.api_base}/models", timeout=5)
+            headers = {}
+            if self.api_key:
+                headers["Authorization"] = f"Bearer {self.api_key}"
+            response = requests.get(f"{self.api_base}/models", headers=headers, timeout=5)
             if response.status_code == 200:
                 return True, response.json()
             return False, f"Server returned status code: {response.status_code}"
@@ -79,9 +85,12 @@ def chat_completion(self,
             # Only print if verbose mode is enabled
             if os.environ.get('SDK_VERBOSE', 'false').lower() == 'true':
                 print(f"Sending request to model {self.model}...")
+            headers = {"Content-Type": "application/json"}
+            if self.api_key:
+                headers["Authorization"] = f"Bearer {self.api_key}"
             response = requests.post(
                 f"{self.api_base}/chat/completions",
-                headers={"Content-Type": "application/json"},
+                headers=headers,
                 data=json.dumps(data),
                 timeout=180  # Increased timeout to 180 seconds
             )
@@ -143,9 +152,12 @@ def batch_completion(self,
             if verbose:
                 print(f"Sending batch request to model {self.model}...")
+            headers = {"Content-Type": "application/json"}
+            if self.api_key:
+                headers["Authorization"] = f"Bearer {self.api_key}"
 
             response = requests.post(
                 f"{self.api_base}/chat/completions",
-                headers={"Content-Type": "application/json"},
+                headers=headers,
                 data=json.dumps(request_data),
                 timeout=180  # Increased timeout for batch processing
             )
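Note on usage: the key is resolved from the new --api-key option (or api_key constructor argument) first, then from the api_key entry of the vllm config section via vllm_config.get("api_key"); when no key is found, the headers dict stays empty and requests go out unauthenticated, matching the previous behaviour. A minimal sketch of the client-side path, assuming the class defined in llm_client.py is named LLMClient and treating the endpoint, model name, and key below as placeholders:

    # Sketch only: class name, endpoint, model, and key are assumptions, not part of this diff.
    from synthetic_data_kit.models.llm_client import LLMClient

    client = LLMClient(
        api_base="http://localhost:8000/v1",            # placeholder VLLM-compatible endpoint
        model_name="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model name
        api_key="sk-example",                           # sent as "Authorization: Bearer sk-example"
    )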