langextract/inference.py: 7 additions & 1 deletion
@@ -406,6 +406,7 @@ class OpenAILanguageModel(BaseLanguageModel):

   model_id: str = 'gpt-4o-mini'
   api_key: str | None = None
+  base_url: str | None = None
   organization: str | None = None
   format_type: data.FormatType = data.FormatType.JSON
   temperature: float = 0.0
@@ -421,6 +422,7 @@ def __init__(
       self,
       model_id: str = 'gpt-4o-mini',
       api_key: str | None = None,
+      base_url: str | None = None,
       organization: str | None = None,
       format_type: data.FormatType = data.FormatType.JSON,
       temperature: float = 0.0,
@@ -432,6 +434,7 @@ def __init__(
     Args:
       model_id: The OpenAI model ID to use (e.g., 'gpt-4o-mini', 'gpt-4o').
       api_key: API key for OpenAI service.
+      base_url: Base URL for OpenAI service.
       organization: Optional OpenAI organization ID.
       format_type: Output format (JSON or YAML).
       temperature: Sampling temperature.
@@ -441,6 +444,7 @@ def __init__(
     """
     self.model_id = model_id
     self.api_key = api_key
+    self.base_url = base_url
     self.organization = organization
     self.format_type = format_type
     self.temperature = temperature
@@ -452,7 +456,9 @@ def __init__(

     # Initialize the OpenAI client
     self._client = openai.OpenAI(
-        api_key=self.api_key, organization=self.organization
+        api_key=self.api_key,
+        base_url=self.base_url,
+        organization=self.organization,
     )

     super().__init__(
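A minimal usage sketch of the new parameter (not part of the diff; the endpoint URL is illustrative, taken from the test below). With base_url set, the client talks to any OpenAI-compatible server instead of the library's default endpoint:

from langextract import inference

# Point the OpenAI client at an OpenAI-compatible endpoint via base_url.
# The URL is illustrative (e.g. a local proxy or self-hosted gateway);
# leaving base_url as None keeps the client library's default endpoint.
model = inference.OpenAILanguageModel(
    model_id="gpt-4o-mini",
    api_key="test-api-key",
    base_url="http://127.0.0.1:9001/v1",
)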
tests/inference_test.py: 17 additions & 4 deletions
@@ -15,6 +15,7 @@
 from unittest import mock

 from absl.testing import absltest
+from absl.testing import parameterized

 from langextract import data
 from langextract import inference
@@ -94,10 +95,16 @@ def test_ollama_infer(self, mock_ollama_query):
     self.assertEqual(results, expected_results)


-class TestOpenAILanguageModel(absltest.TestCase):
+class TestOpenAILanguageModelInference(parameterized.TestCase):

+  @parameterized.named_parameters(
+      ("without", "test-api-key", None, "gpt-4o-mini", 0.5),
+      ("with", "test-api-key", "http://127.0.0.1:9001/v1", "gpt-4o-mini", 0.5),
+  )
   @mock.patch("openai.OpenAI")
-  def test_openai_infer(self, mock_openai_class):
+  def test_openai_infer_with_parameters(
+      self, api_key, base_url, model_id, temperature, mock_openai_class
+  ):
     # Mock the OpenAI client and chat completion response
     mock_client = mock.Mock()
     mock_openai_class.return_value = mock_client
@@ -111,7 +118,10 @@ def test_openai_infer(self, mock_openai_class):

     # Create model instance
     model = inference.OpenAILanguageModel(
-        model_id="gpt-4o-mini", api_key="test-api-key", temperature=0.5
+        model_id=model_id,
+        api_key=api_key,
+        base_url=base_url,
+        temperature=temperature,
     )

     # Test inference
@@ -133,7 +143,7 @@ def test_openai_infer(self, mock_openai_class):
             "content": "Extract name and age from: John is 30 years old",
             },
         ],
-        temperature=0.5,
+        temperature=temperature,
         max_tokens=None,
         top_p=None,
         n=1,
@@ -145,6 +155,9 @@ def test_openai_infer(self, mock_openai_class):
     ]]
     self.assertEqual(results, expected_results)


+
+class TestOpenAILanguageModel(absltest.TestCase):
+
   def test_openai_parse_output_json(self):
     model = inference.OpenAILanguageModel(
         api_key="test-key", format_type=data.FormatType.JSON
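For reference, a sketch (not part of this change) of an additional check one could add alongside the parameterized pattern above: asserting that base_url is forwarded to the underlying openai.OpenAI client. Case names, the class name, and the URL are illustrative; as in the test above, the parameterized arguments come first and the mock injected by @mock.patch is passed last.

from unittest import mock

from absl.testing import parameterized

from langextract import inference


class OpenAIClientConstructionTest(parameterized.TestCase):

  @parameterized.named_parameters(
      ("without_base_url", None),
      ("with_base_url", "http://127.0.0.1:9001/v1"),
  )
  @mock.patch("openai.OpenAI")
  def test_base_url_forwarded_to_client(self, base_url, mock_openai_class):
    inference.OpenAILanguageModel(
        model_id="gpt-4o-mini",
        api_key="test-api-key",
        base_url=base_url,
    )
    # The constructor should pass base_url straight through to the client;
    # None falls back to the client library's default endpoint.
    mock_openai_class.assert_called_once_with(
        api_key="test-api-key",
        base_url=base_url,
        organization=None,
    )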