langextract/inference.py (164 additions, 0 deletions)

@@ -561,3 +561,167 @@ def parse_output(self, output: str) -> Any:
raise ValueError(
f'Failed to parse output as {self.format_type.name}: {str(e)}'
) from e


@dataclasses.dataclass(init=False)
class CustomLanguageModel(BaseLanguageModel):
"""Language model inference using Custom OpenAI-compatible API servers with structured output."""

model_id: str = 'z-ai/glm-4.5-air:free' # Default model
api_key: str | None = None
base_url: str = (
'https://openrouter.ai/api/v1' # Default to OpenRouter, but customizable
)
format_type: data.FormatType = data.FormatType.JSON
temperature: float = 0.0
max_workers: int = 10
_client: openai.OpenAI | None = dataclasses.field(
default=None, repr=False, compare=False
)
_extra_kwargs: dict[str, Any] = dataclasses.field(
default_factory=dict, repr=False, compare=False
)

def __init__(
self,
model_id: str = 'z-ai/glm-4.5-air:free', # Default model
api_key: str | None = None,
base_url: str = 'https://openrouter.ai/api/v1',
format_type: data.FormatType = data.FormatType.JSON,
temperature: float = 0.0,
max_workers: int = 10,
**kwargs,
) -> None:
"""Initialize the Custom OpenAI-compatible language model.

Args:
model_id: Model ID (format varies by provider).
api_key: API key for the service.
base_url: API endpoint URL. Examples:
- OpenRouter: 'https://openrouter.ai/api/v1'
- Together AI: 'https://api.together.xyz/v1'
format_type: Output format.
temperature: Sampling temperature.
max_workers: Maximum parallel API calls.
**kwargs: Additional parameters.
"""

self.model_id = model_id
self.api_key = api_key
self.base_url = base_url
self.format_type = format_type
self.temperature = temperature
self.max_workers = max_workers
self._extra_kwargs = kwargs or {}

if not self.api_key:
raise ValueError('API key not provided.')

self._client = openai.OpenAI(api_key=self.api_key, base_url=self.base_url)

super().__init__(
constraint=schema.Constraint(constraint_type=schema.ConstraintType.NONE)
)

def _process_single_prompt(self, prompt: str, config: dict) -> ScoredOutput:
"""Process a single prompt and return a ScoredOutput."""
try:
# Prepare the system message for structured output
system_message = ''
if self.format_type == data.FormatType.JSON:
system_message = (
'You are a helpful assistant that responds in JSON format.'
)
elif self.format_type == data.FormatType.YAML:
system_message = (
'You are a helpful assistant that responds in YAML format.'
)

# Create the chat completion using the v1.x client API
response = self._client.chat.completions.create(
model=self.model_id,
messages=[
{'role': 'system', 'content': system_message},
{'role': 'user', 'content': prompt},
],
temperature=config.get('temperature', self.temperature),
max_tokens=config.get('max_output_tokens'),
top_p=config.get('top_p'),
n=1,
)

# Extract the response text using the v1.x response format
output_text = response.choices[0].message.content

return ScoredOutput(score=1.0, output=output_text)

except Exception as e:
raise InferenceOutputError(f'API error: {str(e)}') from e

def infer(
self, batch_prompts: Sequence[str], **kwargs
) -> Iterator[Sequence[ScoredOutput]]:
"""Runs inference on a list of prompts via OpenAI Compatible API format.

Args:
batch_prompts: A list of string prompts.
      **kwargs: Additional generation params (temperature, top_p, etc.).

Yields:
Lists of ScoredOutputs.
"""
config = {
'temperature': kwargs.get('temperature', self.temperature),
}
if 'max_output_tokens' in kwargs:
config['max_output_tokens'] = kwargs['max_output_tokens']
if 'top_p' in kwargs:
config['top_p'] = kwargs['top_p']

# Use parallel processing for batches larger than 1
if len(batch_prompts) > 1 and self.max_workers > 1:
with concurrent.futures.ThreadPoolExecutor(
max_workers=min(self.max_workers, len(batch_prompts))
) as executor:
future_to_index = {
executor.submit(
self._process_single_prompt, prompt, config.copy()
): i
for i, prompt in enumerate(batch_prompts)
}

results: list[ScoredOutput | None] = [None] * len(batch_prompts)
for future in concurrent.futures.as_completed(future_to_index):
index = future_to_index[future]
try:
results[index] = future.result()
except Exception as e:
raise InferenceOutputError(
f'Parallel inference error: {str(e)}'
) from e

for result in results:
if result is None:
raise InferenceOutputError('Failed to process one or more prompts')
yield [result]
else:
      # Sequential processing for a single prompt or a single worker.
for prompt in batch_prompts:
result = self._process_single_prompt(prompt, config.copy())
yield [result]

def parse_output(self, output: str) -> Any:
"""Parses output as JSON or YAML.

Note: This expects raw JSON/YAML without code fences.
Code fence extraction is handled by resolver.py.
"""
try:
if self.format_type == data.FormatType.JSON:
return json.loads(output)
else:
return yaml.safe_load(output)
except Exception as e:
raise ValueError(
f'Failed to parse output as {self.format_type.name}: {str(e)}'
) from e
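
A minimal usage sketch for the new class may help here; it is not part of the diff. The API key and prompt are placeholders, and since only base_url changes, any OpenAI-compatible provider should work the same way.

# Usage sketch (illustrative, not from this PR). Assumes an OpenRouter key;
# swap base_url/model_id for any other OpenAI-compatible provider.
from langextract import inference

model = inference.CustomLanguageModel(
  model_id='z-ai/glm-4.5-air:free',
  api_key='YOUR_API_KEY',  # __init__ raises ValueError if this is missing
  base_url='https://openrouter.ai/api/v1',
  temperature=0.0,
)

# infer() yields one list of ScoredOutputs per prompt, in input order.
for scored_outputs in model.infer(
  ['Extract name and age from: John is 30 years old'],
  max_output_tokens=256,  # forwarded to the API as max_tokens
):
  print(model.parse_output(scored_outputs[0].output))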
tests/inference_test.py (148 additions, 0 deletions)

@@ -209,5 +209,153 @@ def test_openai_temperature_zero(self, mock_openai_class):
)


class TestCustomLanguageModel(absltest.TestCase):

@mock.patch("openai.OpenAI")
def test_custom_infer(self, mock_openai_class):
# Mock the OpenAI client and chat completion response
mock_client = mock.Mock()
mock_openai_class.return_value = mock_client

# Mock response structure for v1.x API
mock_response = mock.Mock()
mock_response.choices = [
mock.Mock(message=mock.Mock(content='{"name": "John", "age": 30}'))
]
mock_client.chat.completions.create.return_value = mock_response

# Create model instance
model = inference.CustomLanguageModel(
model_id="z-ai/glm-4.5-air:free",
api_key="test-api-key",
temperature=0.5,
)

# Test inference
batch_prompts = ["Extract name and age from: John is 30 years old"]
results = list(model.infer(batch_prompts))

    # Verify the OpenAI client was initialized with the default OpenRouter base URL
mock_openai_class.assert_called_once_with(
api_key="test-api-key", base_url="https://openrouter.ai/api/v1"
)

# Verify API was called correctly
mock_client.chat.completions.create.assert_called_once_with(
model="z-ai/glm-4.5-air:free",
messages=[
{
"role": "system",
"content": (
"You are a helpful assistant that responds in JSON format."
),
},
{
"role": "user",
"content": "Extract name and age from: John is 30 years old",
},
],
temperature=0.5,
max_tokens=None,
top_p=None,
n=1,
)

# Check results
expected_results = [[
inference.ScoredOutput(score=1.0, output='{"name": "John", "age": 30}')
]]
self.assertEqual(results, expected_results)

def test_custom_parse_output_json(self):
model = inference.CustomLanguageModel(
api_key="test-key", format_type=data.FormatType.JSON
)

# Test valid JSON parsing
output = '{"key": "value", "number": 42}'
parsed = model.parse_output(output)
self.assertEqual(parsed, {"key": "value", "number": 42})

# Test invalid JSON
with self.assertRaises(ValueError) as context:
model.parse_output("invalid json")
self.assertIn("Failed to parse output as JSON", str(context.exception))

def test_custom_parse_output_yaml(self):
model = inference.CustomLanguageModel(
api_key="test-key", format_type=data.FormatType.YAML
)

# Test valid YAML parsing
output = "key: value\nnumber: 42"
parsed = model.parse_output(output)
self.assertEqual(parsed, {"key": "value", "number": 42})

# Test invalid YAML
with self.assertRaises(ValueError) as context:
model.parse_output("invalid: yaml: bad")
self.assertIn("Failed to parse output as YAML", str(context.exception))

def test_custom_no_api_key_raises_error(self):
with self.assertRaises(ValueError) as context:
inference.CustomLanguageModel(api_key=None)
self.assertEqual(str(context.exception), "API key not provided.")

@mock.patch("openai.OpenAI")
def test_custom_temperature_zero(self, mock_openai_class):
# Test that temperature=0.0 is properly passed through
mock_client = mock.Mock()
mock_openai_class.return_value = mock_client

mock_response = mock.Mock()
mock_response.choices = [
mock.Mock(message=mock.Mock(content='{"result": "test"}'))
]
mock_client.chat.completions.create.return_value = mock_response

model = inference.CustomLanguageModel(
api_key="test-key", temperature=0.0 # Testing zero temperature
)

list(model.infer(["test prompt"]))

# Verify temperature=0.0 was passed to the API
mock_client.chat.completions.create.assert_called_with(
model="z-ai/glm-4.5-air:free",
messages=mock.ANY,
temperature=0.0,
max_tokens=None,
top_p=None,
n=1,
)

@mock.patch("openai.OpenAI")
def test_custom_custom_base_url(self, mock_openai_class):
# Test that custom base URL is properly used
mock_client = mock.Mock()
mock_openai_class.return_value = mock_client

custom_base_url = "https://custom-language-model.example.com/api/v1"
model = inference.CustomLanguageModel(
api_key="test-key", base_url=custom_base_url
)

# Verify OpenAI client was initialized with custom base URL
mock_openai_class.assert_called_once_with(
api_key="test-key", base_url=custom_base_url
)

def test_custom_default_model(self):
    # Test that the default model is z-ai/glm-4.5-air:free
model = inference.CustomLanguageModel(api_key="test-key")
self.assertEqual(model.model_id, "z-ai/glm-4.5-air:free")

def test_custom_default_base_url(self):
    # Test that the default base URL is OpenRouter's API
model = inference.CustomLanguageModel(api_key="test-key")
self.assertEqual(model.base_url, "https://openrouter.ai/api/v1")


if __name__ == "__main__":
absltest.main()
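
Because only base_url changes, the same class should also drive a self-hosted OpenAI-compatible server. A sketch follows; the URL, port, and model name are assumptions for illustration only.

# Sketch: pointing CustomLanguageModel at a local OpenAI-compatible server
# (e.g. vLLM or llama.cpp). URL, port, and model name are illustrative.
from langextract import inference

local_model = inference.CustomLanguageModel(
  model_id='meta-llama/Llama-3.1-8B-Instruct',
  api_key='EMPTY',  # many local servers accept any non-empty placeholder
  base_url='http://localhost:8000/v1',
  max_workers=4,  # caps the ThreadPoolExecutor used for batched prompts
)

The tests above run standalone via python tests/inference_test.py, since the module calls absltest.main().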