Feature: Add API_KEY option for CLI #10

Open · wants to merge 2 commits into main
27 changes: 24 additions & 3 deletions synthetic_data_kit/cli.py
@@ -46,6 +46,9 @@ def callback(
 def system_check(
     api_base: Optional[str] = typer.Option(
         None, "--api-base", help="VLLM API base URL to check"
+    ),
+    api_key: Optional[str] = typer.Option(
+        None, "--api-key", help="API key for authentication"
     )
 ):
     """
@@ -54,10 +57,14 @@ def system_check(
     # Get VLLM server details from args or config
     vllm_config = get_vllm_config(ctx.config)
     api_base = api_base or vllm_config.get("api_base")
+    api_key = api_key or vllm_config.get("api_key")
+    headers = {}
+    if api_key:
+        headers["Authorization"] = f"Bearer {api_key}"
 
     with console.status(f"Checking VLLM server at {api_base}..."):
         try:
-            response = requests.get(f"{api_base}/models", timeout=2)
+            response = requests.get(f"{api_base}/models", headers=headers, timeout=2)
             if response.status_code == 200:
                 console.print(f"✅ VLLM server is running at {api_base}", style="green")
                 console.print(f"Available models: {response.json()}")
@@ -121,6 +128,9 @@ def create(
     model: Optional[str] = typer.Option(
         None, "--model", "-m", help="Model to use"
     ),
+    api_key: Optional[str] = typer.Option(
+        None, "--api-key", help="API key for authentication"
+    ),
     num_pairs: Optional[int] = typer.Option(
         None, "--num-pairs", "-n", help="Target number of QA pairs to generate"
     ),
@@ -147,10 +157,14 @@ def create(
     vllm_config = get_vllm_config(ctx.config)
     api_base = api_base or vllm_config.get("api_base")
     model = model or vllm_config.get("model")
+    api_key = api_key or vllm_config.get("api_key")
+    headers = {}
+    if api_key:
+        headers["Authorization"] = f"Bearer {api_key}"
 
     # Check server first
     try:
-        response = requests.get(f"{api_base}/models", timeout=2)
+        response = requests.get(f"{api_base}/models", headers=headers, timeout=2)
         if response.status_code != 200:
             console.print(f"❌ Error: VLLM server not available at {api_base}", style="red")
             console.print("Please start the VLLM server with:", style="yellow")
@@ -201,6 +215,9 @@ def curate(
     model: Optional[str] = typer.Option(
         None, "--model", "-m", help="Model to use"
     ),
+    api_key: Optional[str] = typer.Option(
+        None, "--api-key", help="API key for authentication"
+    ),
     verbose: bool = typer.Option(
         False, "--verbose", "-v", help="Show detailed output"
     ),
@@ -214,10 +231,14 @@ def curate(
     vllm_config = get_vllm_config(ctx.config)
     api_base = api_base or vllm_config.get("api_base")
     model = model or vllm_config.get("model")
+    api_key = api_key or vllm_config.get("api_key")
+    headers = {}
+    if api_key:
+        headers["Authorization"] = f"Bearer {api_key}"
 
     # Check server first
     try:
-        response = requests.get(f"{api_base}/models", timeout=2)
+        response = requests.get(f"{api_base}/models", headers=headers, timeout=2)
         if response.status_code != 200:
             console.print(f"❌ Error: VLLM server not available at {api_base}", style="red")
             console.print("Please start the VLLM server with:", style="yellow")
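All three commands repeat the same pattern: take the key from the new --api-key flag, fall back to api_key in the vllm section of the config, and attach it as a Bearer token when probing the server's /models endpoint. A minimal standalone sketch of that pattern (probe_server and the URL/key values here are illustrative, not part of this PR):

    import requests
    from typing import Optional

    def probe_server(api_base: str, api_key: Optional[str] = None) -> bool:
        """Return True if {api_base}/models answers 200, sending a
        Bearer token only when a key is given."""
        headers = {}
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"
        try:
            response = requests.get(f"{api_base}/models", headers=headers, timeout=2)
            return response.status_code == 200
        except requests.exceptions.RequestException:
            return False

    # e.g. probe_server("http://localhost:8000/v1", api_key="sk-example")

Because headers stays an empty dict when no key is supplied, unauthenticated servers see exactly the same request as before this change.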
20 changes: 16 additions & 4 deletions synthetic_data_kit/models/llm_client.py
@@ -19,7 +19,8 @@ def __init__(self,
                  api_base: Optional[str] = None,
                  model_name: Optional[str] = None,
                  max_retries: Optional[int] = None,
-                 retry_delay: Optional[float] = None):
+                 retry_delay: Optional[float] = None,
+                 api_key: Optional[str] = None):  # added api_key parameter
         """Initialize an OpenAI-compatible client that connects to a VLLM server
 
         Args:
@@ -28,6 +29,7 @@ def __init__(self,
             model_name: Override model name from config
             max_retries: Override max retries from config
             retry_delay: Override retry delay from config
+            api_key: Override API key from config
         """
         # Load config
         self.config = load_config(config_path)
@@ -38,6 +40,7 @@ def __init__(self,
         self.model = model_name or vllm_config.get('model')
         self.max_retries = max_retries or vllm_config.get('max_retries')
         self.retry_delay = retry_delay or vllm_config.get('retry_delay')
+        self.api_key = api_key or vllm_config.get('api_key')  # save API key
 
         # Verify server is running
         available, info = self._check_server()
@@ -47,7 +50,10 @@
     def _check_server(self) -> tuple:
         """Check if the VLLM server is running and accessible"""
         try:
-            response = requests.get(f"{self.api_base}/models", timeout=5)
+            headers = {}
+            if self.api_key:
+                headers["Authorization"] = f"Bearer {self.api_key}"
+            response = requests.get(f"{self.api_base}/models", headers=headers, timeout=5)
             if response.status_code == 200:
                 return True, response.json()
             return False, f"Server returned status code: {response.status_code}"
@@ -79,9 +85,12 @@ def chat_completion(self,
         # Only print if verbose mode is enabled
         if os.environ.get('SDK_VERBOSE', 'false').lower() == 'true':
             print(f"Sending request to model {self.model}...")
+        headers = {"Content-Type": "application/json"}
+        if self.api_key:
+            headers["Authorization"] = f"Bearer {self.api_key}"
         response = requests.post(
             f"{self.api_base}/chat/completions",
-            headers={"Content-Type": "application/json"},
+            headers=headers,
             data=json.dumps(data),
             timeout=180  # Increased timeout to 180 seconds
         )
@@ -143,9 +152,12 @@ def batch_completion(self,
         if verbose:
             print(f"Sending batch request to model {self.model}...")
 
+        headers = {"Content-Type": "application/json"}
+        if self.api_key:
+            headers["Authorization"] = f"Bearer {self.api_key}"
         response = requests.post(
             f"{self.api_base}/chat/completions",
-            headers={"Content-Type": "application/json"},
+            headers=headers,
             data=json.dumps(request_data),
             timeout=180  # Increased timeout for batch processing
         )
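Taken together, the client-side changes mean a key passed to the constructor, or found under api_key in the vllm config section, is sent on the startup check and on every chat/batch completion request. A hedged usage sketch: the class name LLMClient is inferred from the file path, the endpoint, model, and key are placeholders, and the OpenAI-style messages argument to chat_completion is an assumption, since the full signature is truncated in this diff:

    # Class name and import path are assumptions based on
    # synthetic_data_kit/models/llm_client.py.
    from synthetic_data_kit.models.llm_client import LLMClient

    # Keyword arguments mirror the __init__ signature in the diff above;
    # the endpoint, model, and key values are placeholders.
    client = LLMClient(
        api_base="http://localhost:8000/v1",
        model_name="my-model",
        api_key="sk-example",
    )

    # Assumed OpenAI-style call shape; the full chat_completion
    # signature is not shown in this diff.
    print(client.chat_completion(messages=[{"role": "user", "content": "Hello"}]))

Since __init__ calls _check_server before returning, a wrong or missing key should surface as a failed server check at construction time rather than on the first completion call.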