diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml
index dc166fa9..e3bacb8a 100644
--- a/.github/workflows/ci.yaml
+++ b/.github/workflows/ci.yaml
@@ -133,6 +133,53 @@ jobs:
           fi
           tox -e live-api
 
+  plugin-integration-test:
+    needs: test
+    runs-on: ubuntu-latest
+    if: github.event_name == 'pull_request'
+
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Detect provider-related changes
+        id: provider-changes
+        uses: tj-actions/changed-files@v46
+        with:
+          files: |
+            langextract/providers/**
+            langextract/factory.py
+            langextract/inference.py
+            tests/provider_plugin_test.py
+            pyproject.toml
+            .github/workflows/ci.yaml
+
+      - name: Skip if no provider changes
+        if: steps.provider-changes.outputs.any_changed == 'false'
+        run: |
+          echo "No provider-related changes detected – skipping plugin integration test."
+
+      # Ending the step above does not end the job, so every remaining step is
+      # gated on the change-detection output to make the skip effective.
+      - name: Set up Python 3.11
+        if: steps.provider-changes.outputs.any_changed == 'true'
+        uses: actions/setup-python@v4
+        with:
+          python-version: "3.11"
+
+      - name: Install dependencies
+        if: steps.provider-changes.outputs.any_changed == 'true'
+        run: |
+          python -m pip install --upgrade pip
+          pip install tox
+
+      - name: Run plugin smoke test
+        if: steps.provider-changes.outputs.any_changed == 'true'
+        run: tox -e plugin-smoke
+
+      - name: Run plugin integration test
+        if: steps.provider-changes.outputs.any_changed == 'true'
+        run: tox -e plugin-integration
+
   ollama-integration-test:
     needs: test
     runs-on: ubuntu-latest
diff --git a/langextract/providers/__init__.py b/langextract/providers/__init__.py
index c0d46003..85631f48 100644
--- a/langextract/providers/__init__.py
+++ b/langextract/providers/__init__.py
@@ -69,27 +69,42 @@ def load_plugins_once() -> None:
     _PLUGINS_LOADED = True
     return
 
-  _PLUGINS_LOADED = True
-
   try:
     entry_points_group = metadata.entry_points(group="langextract.providers")
   except Exception as exc:
     logging.debug("No third-party provider entry points found: %s", exc)
     return
 
+  # Set flag after successful entry point query to avoid disabling discovery
+  # on transient failures during enumeration.
+  _PLUGINS_LOADED = True
+
   for entry_point in entry_points_group:
     try:
       provider = entry_point.load()
-
+      # Validate provider subclasses but don't auto-register - plugins must
+      # use their own @register decorators to control patterns.
       if isinstance(provider, type):
-        registry.register(entry_point.name)(provider)
-        logging.info(
-            "Registered third-party provider from entry point: %s",
-            entry_point.name,
-        )
+        # pylint: disable=import-outside-toplevel
+        # Late import to avoid circular dependency
+        from langextract import inference
+
+        if issubclass(provider, inference.BaseLanguageModel):
+          logging.info(
+              "Loaded third-party provider class from entry point: %s",
+              entry_point.name,
+          )
+        else:
+          logging.warning(
+              "Entry point %s returned non-provider class %r; ignoring",
+              entry_point.name,
+              provider,
+          )
       else:
+        # Module import triggers decorator execution
         logging.debug(
-            "Loaded provider module from entry point: %s", entry_point.name
+            "Loaded provider module/object from entry point: %s",
+            entry_point.name,
         )
     except Exception as exc:
       logging.warning(
diff --git a/pyproject.toml b/pyproject.toml
index 47c8a707..3831c844 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -92,6 +92,8 @@ python_functions = "test_*"
 addopts = "-ra"
 markers = [
     "live_api: marks tests as requiring live API access",
+    "requires_pip: marks tests that perform pip install/uninstall operations",
+    "integration: marks integration tests that test multiple components together",
 ]
 
 [tool.pyink]
diff --git a/tests/provider_plugin_test.py b/tests/provider_plugin_test.py
new file mode 100644
index 00000000..db6bed9c
--- /dev/null
+++ b/tests/provider_plugin_test.py
@@ -0,0 +1,408 @@
+# Copyright 2025 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for provider plugin system."""
+
+from importlib import metadata
+import os
+from pathlib import Path
+import subprocess
+import sys
+import tempfile
+import textwrap
+import types
+from unittest import mock
+import uuid
+
+from absl.testing import absltest
+import pytest
+
+import langextract as lx
+
+
+class PluginSmokeTest(absltest.TestCase):
+  """Basic smoke tests for plugin loading functionality."""
+
+  def setUp(self):
+    super().setUp()
+    lx.providers.registry.clear()
+    lx.providers._PLUGINS_LOADED = False
+    self.addCleanup(lx.providers.registry.clear)
+    self.addCleanup(setattr, lx.providers, "_PLUGINS_LOADED", False)
+
+  def test_plugin_discovery_and_usage(self):
+    """Test plugin discovery via entry points.
+
+    Entry points can return a class or module. Registration happens via
+    the @register decorator in both cases.
+    """
+
+    def _ep_load():
+      @lx.providers.registry.register(r"^plugin-model")
+      class PluginProvider(lx.inference.BaseLanguageModel):
+
+        def __init__(self, model_id=None, **kwargs):
+          super().__init__()
+          self.model_id = model_id
+
+        def infer(self, batch_prompts, **kwargs):
+          return [[lx.inference.ScoredOutput(score=1.0, output="ok")]]
+
+      return PluginProvider
+
+    ep = types.SimpleNamespace(
+        name="plugin_provider",
+        group="langextract.providers",
+        value="my_pkg:PluginProvider",
+        load=_ep_load,
+    )
+
+    with mock.patch.object(
+        metadata,
+        "entry_points",
+        side_effect=lambda **kw: [ep]
+        if kw.get("group") == "langextract.providers"
+        else [],
+    ):
+      lx.providers.load_plugins_once()
+
+    resolved_cls = lx.providers.registry.resolve("plugin-model-123")
+    self.assertEqual(
+        resolved_cls.__name__,
+        "PluginProvider",
+        "Provider should be resolvable after plugin load",
+    )
+
+    cfg = lx.factory.ModelConfig(model_id="plugin-model-123")
+    model = lx.factory.create_model(cfg)
+
+    out = model.infer(["hi"])[0][0].output
+    self.assertEqual(out, "ok", "Provider should return expected output")
+
+  def test_plugin_disabled_by_env_var(self):
+    """Test that LANGEXTRACT_DISABLE_PLUGINS=1 prevents plugin loading."""
+
+    with mock.patch.dict("os.environ", {"LANGEXTRACT_DISABLE_PLUGINS": "1"}):
+      with mock.patch.object(metadata, "entry_points") as mock_ep:
+        lx.providers.load_plugins_once()
+        mock_ep.assert_not_called()
+
+  def test_handles_import_errors_gracefully(self):
+    """Test that import errors during plugin loading don't crash."""
+
+    def _bad_load():
+      raise ImportError("Plugin not found")
+
+    bad_ep = types.SimpleNamespace(
+        name="bad_plugin",
+        group="langextract.providers",
+        value="bad_pkg:BadProvider",
+        load=_bad_load,
+    )
+
+    with mock.patch.object(metadata, "entry_points", return_value=[bad_ep]):
+      lx.providers.load_plugins_once()
+
+      providers = lx.providers.registry.list_providers()
+      self.assertIsInstance(
+          providers,
+          list,
+          "Registry should remain functional after import error",
+      )
+      self.assertEqual(
+          len(providers),
+          0,
+          "Broken EP should not partially register",
+      )
+
+  def test_load_plugins_once_is_idempotent(self):
+    """Test that load_plugins_once only discovers once."""
+
+    def _ep_load():
+      @lx.providers.registry.register(r"^plugin-model")
+      class Plugin(lx.inference.BaseLanguageModel):
+
+        def infer(self, *a, **k):
+          return [[lx.inference.ScoredOutput(score=1.0, output="ok")]]
+
+      return Plugin
+
+    ep = types.SimpleNamespace(
+        name="plugin_provider",
+        group="langextract.providers",
+        value="pkg:Plugin",
+        load=_ep_load,
+    )
+
+    with mock.patch.object(
+        metadata,
+        "entry_points",
+        side_effect=lambda **kw: [ep]
+        if kw.get("group") == "langextract.providers"
+        else [],
+    ) as m:
+      lx.providers.load_plugins_once()
+      lx.providers.load_plugins_once()  # should be a no-op
+      self.assertEqual(m.call_count, 1, "Discovery should happen only once")
+
+  def test_non_subclass_entry_point_does_not_crash(self):
+    """Test that non-BaseLanguageModel classes don't crash the system."""
+
+    class NotAProvider:  # pylint: disable=too-few-public-methods
+      """Dummy class to test non-provider handling."""
+
+    bad_ep = types.SimpleNamespace(
+        name="bad",
+        group="langextract.providers",
+        value="bad:NotAProvider",
+        load=lambda: NotAProvider,
+    )
+
+    with mock.patch.object(
+        metadata,
+        "entry_points",
+        side_effect=lambda **kw: [bad_ep]
+        if kw.get("group") == "langextract.providers"
+        else [],
+    ):
+      lx.providers.load_plugins_once()
+      # The system should remain functional even if a bad provider is loaded
+      # Trying to use it would fail, but discovery shouldn't crash
+      providers = lx.providers.registry.list_providers()
+      self.assertIsInstance(
+          providers,
+          list,
+          "Registry should remain functional with bad provider",
+      )
+      with self.assertRaisesRegex(ValueError, "No provider registered"):
+        lx.providers.registry.resolve("bad")
+
+  def test_plugin_priority_override_core_provider(self):
+    """Plugin with higher priority should override core provider on conflicts."""
+
+    lx.providers.registry.clear()
+    lx.providers._PLUGINS_LOADED = False
+
+    def _ep_load():
+      @lx.providers.registry.register(r"^gemini", priority=50)
+      class OverrideGemini(lx.inference.BaseLanguageModel):
+
+        def infer(self, batch_prompts, **kwargs):
+          return [[lx.inference.ScoredOutput(score=1.0, output="override")]]
+
+      return OverrideGemini
+
+    ep = types.SimpleNamespace(
+        name="override_gemini",
+        group="langextract.providers",
+        value="pkg:OverrideGemini",
+        load=_ep_load,
+    )
+
+    with mock.patch.object(
+        metadata,
+        "entry_points",
+        side_effect=lambda **kw: [ep]
+        if kw.get("group") == "langextract.providers"
+        else [],
+    ):
+      lx.providers.load_plugins_once()
+
+    # Core gemini registers with priority 10 in providers.gemini
+    # Our plugin registered with priority 50; it should win.
+    resolved = lx.providers.registry.resolve("gemini-2.5-flash")
+    self.assertEqual(resolved.__name__, "OverrideGemini")
+
+  def test_resolve_provider_for_plugin(self):
+    """resolve_provider should match exact and lowercase-partial class names."""
+
+    lx.providers.registry.clear()
+    lx.providers._PLUGINS_LOADED = False
+
+    def _ep_load():
+      @lx.providers.registry.register(r"^plugin-resolve")
+      class ResolveMePlease(lx.inference.BaseLanguageModel):
+
+        def infer(self, batch_prompts, **kwargs):
+          return [[lx.inference.ScoredOutput(score=1.0, output="ok")]]
+
+      return ResolveMePlease
+
+    ep = types.SimpleNamespace(
+        name="resolver_plugin",
+        group="langextract.providers",
+        value="pkg:ResolveMePlease",
+        load=_ep_load,
+    )
+
+    with mock.patch.object(
+        metadata,
+        "entry_points",
+        side_effect=lambda **kw: [ep]
+        if kw.get("group") == "langextract.providers"
+        else [],
+    ):
+      lx.providers.load_plugins_once()
+
+    cls_by_exact = lx.providers.registry.resolve_provider("ResolveMePlease")
+    self.assertEqual(cls_by_exact.__name__, "ResolveMePlease")
+
+    cls_by_partial = lx.providers.registry.resolve_provider("resolveme")
+    self.assertEqual(cls_by_partial.__name__, "ResolveMePlease")
+
+
+class PluginE2ETest(absltest.TestCase):
+  """End-to-end test with actual pip installation.
+
+  This test is expensive and only runs when explicitly requested
+  via tox -e plugin-integration or in CI when provider files change.
+  """
+
+  @pytest.mark.requires_pip
+  @pytest.mark.integration
+  def test_pip_install_discovery_and_cleanup(self):
+    """Test complete plugin lifecycle: install, discovery, usage, uninstall.
+
+    This test:
+    1. Creates a Python package with a provider plugin
+    2. Installs it via pip
+    3. Verifies the plugin is discovered and usable
+    4. Uninstalls and verifies cleanup
+    """
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+      pkg_name = f"test_langextract_plugin_{uuid.uuid4().hex[:8]}"
+      pkg_dir = Path(tmpdir) / pkg_name
+      pkg_dir.mkdir()
+
+      (pkg_dir / pkg_name).mkdir()
+      (pkg_dir / pkg_name / "__init__.py").write_text("")
+
+      (pkg_dir / pkg_name / "provider.py").write_text(textwrap.dedent("""
+        import langextract as lx
+
+        USED_BY_EXTRACT = False
+
+        @lx.providers.registry.register(r'^test-pip-model', priority=50)
+        class TestPipProvider(lx.inference.BaseLanguageModel):
+          def __init__(self, model_id, **kwargs):
+            super().__init__()
+            self.model_id = model_id
+
+          def infer(self, batch_prompts, **kwargs):
+            global USED_BY_EXTRACT
+            USED_BY_EXTRACT = True
+            return [[lx.inference.ScoredOutput(score=1.0, output="pip test response")]]
+      """))
+
+      (pkg_dir / "pyproject.toml").write_text(textwrap.dedent(f"""
+        [build-system]
+        requires = ["setuptools>=61.0"]
+        build-backend = "setuptools.build_meta"
+
+        [project]
+        name = "{pkg_name}"
+        version = "0.0.1"
+        description = "Test plugin for langextract"
+
+        [project.entry-points."langextract.providers"]
+        test_provider = "{pkg_name}.provider:TestPipProvider"
+      """))
+
+      pip_env = {
+          **os.environ,
+          "PIP_NO_INPUT": "1",
+          "PIP_DISABLE_PIP_VERSION_CHECK": "1",
+      }
+      result = subprocess.run(
+          [
+              sys.executable,
+              "-m",
+              "pip",
+              "install",
+              "-e",
+              str(pkg_dir),
+              "--no-deps",
+              "-q",
+          ],
+          check=True,
+          capture_output=True,
+          text=True,
+          env=pip_env,
+      )
+
+      self.assertEqual(result.returncode, 0, "pip install failed")
+      self.assertNotIn(
+          "ERROR",
+          result.stderr.upper(),
+          f"pip install had errors: {result.stderr}",
+      )
+
+      try:
+        test_script = Path(tmpdir) / "test_plugin.py"
+        test_script.write_text(textwrap.dedent(f"""
+          import langextract as lx
+          import sys
+
+          lx.providers.load_plugins_once()
+
+          # Test via factory.create_model
+          cfg = lx.factory.ModelConfig(model_id="test-pip-model-123")
+          model = lx.factory.create_model(cfg)
+          result = model.infer(["test prompt"])
+          assert result[0][0].output == "pip test response", f"Got: {{result[0][0].output}}"
+
+          # Verify the plugin is resolvable via the registry
+          resolved = lx.providers.registry.resolve("test-pip-model-xyz")
+          assert resolved.__name__ == "TestPipProvider", "Plugin should be resolvable"
+
+          from {pkg_name}.provider import USED_BY_EXTRACT
+          assert USED_BY_EXTRACT, "Provider infer() was not called"
+
+          print("SUCCESS: Plugin test passed")
+        """))
+
+        result = subprocess.run(
+            [sys.executable, str(test_script)],
+            capture_output=True,
+            text=True,
+            check=False,
+        )
+
+        self.assertIn(
+            "SUCCESS",
+            result.stdout,
+            f"Test failed. stdout: {result.stdout}, stderr: {result.stderr}",
+        )
+
+      finally:
+        subprocess.run(
+            [sys.executable, "-m", "pip", "uninstall", "-y", pkg_name],
+            check=False,
+            capture_output=True,
+            env=pip_env,
+        )
+
+        lx.providers.registry.clear()
+        lx.providers._PLUGINS_LOADED = False
+        lx.providers.load_plugins_once()
+
+        with self.assertRaisesRegex(
+            ValueError, "No provider registered for model_id='test-pip-model"
+        ):
+          lx.providers.registry.resolve("test-pip-model-789")
+
+
+if __name__ == "__main__":
+  absltest.main()
diff --git a/tox.ini b/tox.ini
index 70cdcf12..d9523ff2 100644
--- a/tox.ini
+++ b/tox.ini
@@ -22,7 +22,7 @@ setenv =
 deps =
     .[openai,dev,test]
 commands =
-    pytest -ra -m "not live_api"
+    pytest -ra -m "not live_api and not requires_pip"
 
 [testenv:format]
 skip_install = true
@@ -62,3 +62,20 @@ deps =
     requests>=2.25.0
 commands =
     pytest tests/test_ollama_integration.py -v --tb=short
+
+[testenv:plugin-integration]
+basepython = python3.11
+setenv =
+    PIP_NO_INPUT = 1
+    PIP_DISABLE_PIP_VERSION_CHECK = 1
+deps =
+    .[dev,test]
+commands =
+    pytest tests/provider_plugin_test.py::PluginE2ETest -v -m "requires_pip"
+
+[testenv:plugin-smoke]
+basepython = python3.11
+deps =
+    .[dev,test]
+commands =
+    pytest tests/provider_plugin_test.py::PluginSmokeTest -v