4 changes: 4 additions & 0 deletions langextract/factory.py
@@ -27,6 +27,7 @@

 from langextract import exceptions
 from langextract import inference
+from langextract import providers
 from langextract.providers import registry


@@ -107,6 +108,9 @@ def create_model(config: ModelConfig) -> inference.BaseLanguageModel:
     if config.provider:
       provider_class = registry.resolve_provider(config.provider)
     else:
+      # Load providers before pattern matching
+      providers.load_builtins_once()
+      providers.load_plugins_once()
       provider_class = registry.resolve(config.model_id)
   except (ModuleNotFoundError, ImportError) as e:
     raise exceptions.InferenceConfigError(
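With the built-in providers loaded before pattern matching, a caller can build a model from a bare model ID and let the registry choose the provider class. A minimal sketch of that path, using only the factory.ModelConfig and factory.create_model calls exercised by the new test below; the "llama3" model ID is an illustrative value:

from langextract import factory

# No explicit provider: create_model() loads the built-in and plugin
# providers, then resolves the model ID against the registered patterns.
config = factory.ModelConfig(model_id="llama3")
model = factory.create_model(config)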
37 changes: 26 additions & 11 deletions langextract/providers/__init__.py
@@ -26,8 +26,32 @@

 from langextract.providers import registry

-# Track whether plugins have been loaded
+# Track provider loading for lazy initialization
 _PLUGINS_LOADED = False
+_BUILTINS_LOADED = False
+
+
+def load_builtins_once() -> None:
+  """Load built-in providers to register their patterns.
+
+  Idempotent function that ensures provider patterns are available
+  for model resolution.
+  """
+  global _BUILTINS_LOADED  # pylint: disable=global-statement
+  if _BUILTINS_LOADED:
+    return
+
+  # pylint: disable=import-outside-toplevel
+  from langextract.providers import gemini  # noqa: F401
+  from langextract.providers import ollama  # noqa: F401
+
+  try:
+    from langextract.providers import openai  # noqa: F401
+  except ImportError:
+    logging.debug("OpenAI provider not available (optional dependency)")
+  # pylint: enable=import-outside-toplevel
+
+  _BUILTINS_LOADED = True


 def load_plugins_once() -> None:
@@ -73,13 +97,4 @@ def load_plugins_once() -> None:
      )


-# pylint: disable=wrong-import-position
-from langextract.providers import gemini  # noqa: F401
-from langextract.providers import ollama  # noqa: F401
-
-try:
-  from langextract.providers import openai  # noqa: F401
-except ImportError:
-  pass
-
-__all__ = ["registry", "load_plugins_once"]
+__all__ = ["registry", "load_plugins_once", "load_builtins_once"]
2 changes: 2 additions & 0 deletions langextract/providers/registry.py
@@ -121,6 +121,7 @@ def resolve(model_id: str) -> type[inference.BaseLanguageModel]:
   # pylint: disable=import-outside-toplevel
   from langextract import providers

+  providers.load_builtins_once()
   providers.load_plugins_once()

   sorted_entries = sorted(_ENTRIES, key=lambda e: e.priority, reverse=True)
@@ -154,6 +155,7 @@ class name (e.g., "GeminiLanguageModel").
   # pylint: disable=import-outside-toplevel
   from langextract import providers

+  providers.load_builtins_once()
   providers.load_plugins_once()

   for entry in _ENTRIES:
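Both resolution paths sort the registered entries by priority, so when several patterns match the same model ID the highest-priority registration wins; the test below relies on this by registering its fake provider at priority=100. A rough, self-contained sketch of that selection rule with hypothetical entry data, not the registry's internal types:

import re
from dataclasses import dataclass


@dataclass
class Entry:
  """Hypothetical stand-in for a registry entry: pattern, provider name, priority."""
  pattern: re.Pattern
  provider: str
  priority: int


entries = [
    Entry(re.compile(r"^llama"), "builtin_ollama", priority=10),
    Entry(re.compile(r"^llama"), "fake_ollama", priority=100),
]


def resolve(model_id: str) -> str:
  # Highest priority first, mirroring sorted(_ENTRIES, key=lambda e: e.priority, reverse=True).
  for entry in sorted(entries, key=lambda e: e.priority, reverse=True):
    if entry.pattern.match(model_id):
      return entry.provider
  raise ValueError(f"No provider registered for {model_id!r}")


assert resolve("llama3") == "fake_ollama"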
28 changes: 28 additions & 0 deletions tests/factory_test.py
@@ -210,6 +210,34 @@ def infer_batch(self, prompts, batch_size=32):

     self.assertEqual(model.base_url, "http://custom:11434")

+  def test_ollama_models_select_without_api_keys(self):
+    """Test that Ollama models resolve without API keys or explicit type."""
+
+    @registry.register(
+        r"^llama", r"^gemma", r"^mistral", r"^qwen", priority=100
+    )
+    class FakeOllamaProvider(inference.BaseLanguageModel):
+
+      def __init__(self, model_id, **kwargs):
+        self.model_id = model_id
+        super().__init__()
+
+      def infer(self, batch_prompts, **kwargs):
+        return [[inference.ScoredOutput(score=1.0, output="test")]]
+
+      def infer_batch(self, prompts, batch_size=32):
+        return self.infer(prompts)
+
+    test_models = ["llama3", "gemma2:2b", "mistral:7b", "qwen3:0.6b"]
+
+    for model_id in test_models:
+      with self.subTest(model_id=model_id):
+        with mock.patch.dict(os.environ, {}, clear=True):
+          config = factory.ModelConfig(model_id=model_id)
+          model = factory.create_model(config)
+          self.assertIsInstance(model, FakeOllamaProvider)
+          self.assertEqual(model.model_id, model_id)
+
   def test_model_config_fields_are_immutable(self):
     """ModelConfig fields should not be modifiable after creation."""
     config = factory.ModelConfig(