From 4dfdfbea1e8719d8d15776b7ebe6fd2789dc4175 Mon Sep 17 00:00:00 2001 From: Philippe Prados Date: Mon, 27 May 2024 12:00:41 +0200 Subject: [PATCH] Rebase from langchain v0.2.1 --- .../langchain_community/docstore/__init__.py | 6 +- .../docstore/sql_docstore.py | 264 ++++++++++++++++++ libs/community/poetry.lock | 9 +- libs/community/pyproject.toml | 1 + .../unit_tests/docstore/test_sqlstore.py | 89 ++++++ 5 files changed, 363 insertions(+), 6 deletions(-) create mode 100644 libs/community/langchain_community/docstore/sql_docstore.py create mode 100644 libs/community/tests/unit_tests/docstore/test_sqlstore.py diff --git a/libs/community/langchain_community/docstore/__init__.py b/libs/community/langchain_community/docstore/__init__.py index ea98f747acc95..b7f1911a0353e 100644 --- a/libs/community/langchain_community/docstore/__init__.py +++ b/libs/community/langchain_community/docstore/__init__.py @@ -25,6 +25,9 @@ from langchain_community.docstore.in_memory import ( InMemoryDocstore, ) + from langchain_community.docstore.sql_docstore import ( + SQLStore, + ) from langchain_community.docstore.wikipedia import ( Wikipedia, ) @@ -32,6 +35,7 @@ _module_lookup = { "DocstoreFn": "langchain_community.docstore.arbitrary_fn", "InMemoryDocstore": "langchain_community.docstore.in_memory", + "SQLStore": "langchain_community.docstore.sql_docstore", "Wikipedia": "langchain_community.docstore.wikipedia", } @@ -43,4 +47,4 @@ def __getattr__(name: str) -> Any: raise AttributeError(f"module {__name__} has no attribute {name}") -__all__ = ["DocstoreFn", "InMemoryDocstore", "Wikipedia"] +__all__ = ["DocstoreFn", "InMemoryDocstore", "SQLStore", "Wikipedia"] diff --git a/libs/community/langchain_community/docstore/sql_docstore.py b/libs/community/langchain_community/docstore/sql_docstore.py new file mode 100644 index 0000000000000..40446a4e97907 --- /dev/null +++ b/libs/community/langchain_community/docstore/sql_docstore.py @@ -0,0 +1,264 @@ +import contextlib +from pathlib import Path 
import contextlib
from pathlib import Path
from typing import (
    Any,
    AsyncGenerator,
    AsyncIterator,
    Dict,
    Generator,
    Iterator,
    List,
    Optional,
    Sequence,
    Tuple,
    Union,
    cast,
)

from langchain_core.stores import BaseStore, V
from sqlalchemy import (
    Column,
    Engine,
    PickleType,
    and_,
    create_engine,
    delete,
    select,
)
from sqlalchemy.ext.asyncio import (
    AsyncEngine,
    AsyncSession,
    async_sessionmaker,
    create_async_engine,
)
from sqlalchemy.orm import (
    Mapped,
    Session,
    declarative_base,
    mapped_column,
    sessionmaker,
)

Base = declarative_base()


def items_equal(x: Any, y: Any) -> bool:
    """Equality comparator for ``PickleType`` columns.

    SQLAlchemy uses this to decide whether a pickled value has changed;
    plain ``==`` is sufficient for the byte strings stored here.
    """
    return x == y


class Value(Base):  # type: ignore[valid-type,misc]
    """Table used to save values."""

    # ATTENTION:
    # Prior to modifying this table, please determine whether
    # we should create migrations for this table to make sure
    # users do not experience data loss.
    __tablename__ = "docstore"

    # (namespace, key) is the composite primary key; namespace scopes
    # multiple logical stores within one table.
    namespace: Mapped[str] = mapped_column(primary_key=True, index=True, nullable=False)
    key: Mapped[str] = mapped_column(primary_key=True, index=True, nullable=False)
    # Fix: the database column was previously named "earthquake" (a leftover
    # debug name); name it "value" to match the attribute. This table is new
    # in this patch, so no existing schema depends on the old name.
    value: Any = Column("value", PickleType(comparator=items_equal))


# This is a fix of the original SQLStore.
# It can be removed once the upstream PR is merged.
class SQLStore(BaseStore[str, bytes]):
    """BaseStore interface that works on an SQL database.

    Supports both synchronous engines and SQLAlchemy async engines; the mode
    is detected from the engine type (or selected via ``async_mode`` when a
    ``db_url`` is given).

    Examples:
        Create a SQLStore instance and perform operations on it:

        .. code-block:: python

            from langchain_community.docstore import SQLStore

            # Instantiate the SQLStore with an in-memory SQLite database
            sql_store = SQLStore(namespace="test", db_url="sqlite:///:memory:")
            sql_store.create_schema()

            # Set values for keys
            sql_store.mset([("key1", b"value1"), ("key2", b"value2")])

            # Get values for keys
            values = sql_store.mget(["key1", "key2"])  # Returns [b"value1", b"value2"]

            # Delete keys
            sql_store.mdelete(["key1"])

            # Iterate over keys
            for key in sql_store.yield_keys():
                print(key)
    """

    def __init__(
        self,
        *,
        namespace: str,
        db_url: Optional[Union[str, Path]] = None,
        engine: Optional[Union[Engine, AsyncEngine]] = None,
        engine_kwargs: Optional[Dict[str, Any]] = None,
        async_mode: Optional[bool] = None,
    ):
        """Initialize the store.

        Args:
            namespace: Logical namespace; all keys are scoped to it.
            db_url: SQLAlchemy database URL. Mutually exclusive with ``engine``.
            engine: Pre-built (sync or async) engine. Mutually exclusive with
                ``db_url``.
            engine_kwargs: Extra keyword arguments forwarded to
                ``create_engine`` / ``create_async_engine`` (only used with
                ``db_url``).
            async_mode: When building from ``db_url``, whether to create an
                async engine. Defaults to False. Ignored when ``engine`` is
                given (the engine type decides).

        Raises:
            ValueError: If neither or both of ``db_url`` and ``engine`` are
                given.
        """
        if db_url is None and engine is None:
            raise ValueError("Must specify either db_url or engine")

        if db_url is not None and engine is not None:
            raise ValueError("Must specify either db_url or engine, not both")

        _engine: Union[Engine, AsyncEngine]
        if db_url:
            if async_mode is None:
                async_mode = False
            if async_mode:
                _engine = create_async_engine(
                    url=str(db_url),
                    **(engine_kwargs or {}),
                )
            else:
                _engine = create_engine(url=str(db_url), **(engine_kwargs or {}))
        elif engine:
            _engine = engine
        else:
            raise AssertionError("Something went wrong with configuration of engine.")

        # Async mode is derived from the concrete engine type so that a
        # user-supplied engine cannot disagree with the session factory.
        _session_maker: Union[sessionmaker[Session], async_sessionmaker[AsyncSession]]
        if isinstance(_engine, AsyncEngine):
            self.async_mode = True
            _session_maker = async_sessionmaker(bind=_engine)
        else:
            self.async_mode = False
            _session_maker = sessionmaker(bind=_engine)

        self.engine = _engine
        self.dialect = _engine.dialect.name
        self.session_maker = _session_maker
        self.namespace = namespace

    def create_schema(self) -> None:
        """Create the backing table if it does not exist (sync engines)."""
        Base.metadata.create_all(self.engine)

    async def acreate_schema(self) -> None:
        """Create the backing table if it does not exist (async engines)."""
        assert isinstance(self.engine, AsyncEngine)
        async with self.engine.begin() as session:
            await session.run_sync(Base.metadata.create_all)

    def drop(self) -> None:
        """Drop the backing table (sync engines only)."""
        # Fix: pass the engine itself so SQLAlchemy checks out and releases
        # the connection; the previous code leaked the connection it opened
        # with ``self.engine.connect()``.
        Base.metadata.drop_all(bind=cast(Engine, self.engine))

    async def amget(self, keys: Sequence[str]) -> List[Optional[V]]:
        """Get values for the given keys (async).

        Returns:
            A list aligned with ``keys``; missing keys yield ``None``.
        """
        assert isinstance(self.engine, AsyncEngine)
        result: Dict[str, V] = {}
        async with self._make_async_session() as session:
            stmt = select(Value).filter(
                and_(
                    Value.key.in_(keys),
                    Value.namespace == self.namespace,
                )
            )
            for v in await session.scalars(stmt):
                result[v.key] = v.value
        return [result.get(key) for key in keys]

    def mget(self, keys: Sequence[str]) -> List[Optional[bytes]]:
        """Get values for the given keys.

        Returns:
            A list aligned with ``keys``; missing keys yield ``None``.
        """
        result: Dict[str, bytes] = {}

        with self._make_sync_session() as session:
            stmt = select(Value).filter(
                and_(
                    Value.key.in_(keys),
                    Value.namespace == self.namespace,
                )
            )
            for v in session.scalars(stmt):
                result[v.key] = v.value
        return [result.get(key) for key in keys]

    async def amset(self, key_value_pairs: Sequence[Tuple[str, V]]) -> None:
        """Set values, overwriting any existing entries (async)."""
        # Dedupe via dict (last pair wins) for consistency with mset(); the
        # delete-then-insert avoids primary-key conflicts on upsert.
        values: Dict[str, V] = dict(key_value_pairs)
        async with self._make_async_session() as session:
            await self._amdelete(list(values.keys()), session)
            session.add_all(
                [
                    Value(namespace=self.namespace, key=k, value=v)
                    for k, v in values.items()
                ]
            )
            await session.commit()

    def mset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None:
        """Set values, overwriting any existing entries."""
        # Last duplicate key wins; delete-then-insert implements the upsert.
        values: Dict[str, bytes] = dict(key_value_pairs)
        with self._make_sync_session() as session:
            self._mdelete(list(values.keys()), session)
            session.add_all(
                [
                    Value(namespace=self.namespace, key=k, value=v)
                    for k, v in values.items()
                ]
            )
            session.commit()

    def _mdelete(self, keys: Sequence[str], session: Session) -> None:
        """Delete ``keys`` within this namespace using an open session."""
        stmt = delete(Value).filter(
            and_(
                Value.key.in_(keys),
                Value.namespace == self.namespace,
            )
        )
        session.execute(stmt)

    async def _amdelete(self, keys: Sequence[str], session: AsyncSession) -> None:
        """Delete ``keys`` within this namespace using an open async session."""
        stmt = delete(Value).filter(
            and_(
                Value.key.in_(keys),
                Value.namespace == self.namespace,
            )
        )
        await session.execute(stmt)

    def mdelete(self, keys: Sequence[str]) -> None:
        """Delete the given keys (missing keys are ignored)."""
        with self._make_sync_session() as session:
            self._mdelete(keys, session)
            session.commit()

    async def amdelete(self, keys: Sequence[str]) -> None:
        """Delete the given keys (missing keys are ignored, async)."""
        async with self._make_async_session() as session:
            await self._amdelete(keys, session)
            await session.commit()

    def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
        """Yield the keys in this namespace, optionally filtered by prefix.

        Fix: ``prefix`` was previously accepted but silently ignored.
        """
        with self._make_sync_session() as session:
            stmt = select(Value.key).filter(Value.namespace == self.namespace)
            if prefix:
                # autoescape protects literal '%' / '_' in the prefix from
                # being treated as SQL LIKE wildcards.
                stmt = stmt.filter(Value.key.startswith(prefix, autoescape=True))
            for key in session.scalars(stmt):
                yield str(key)

    async def ayield_keys(self, *, prefix: Optional[str] = None) -> AsyncIterator[str]:
        """Yield the keys in this namespace, optionally filtered by prefix.

        Fix: ``prefix`` was previously accepted but silently ignored.
        """
        async with self._make_async_session() as session:
            stmt = select(Value.key).filter(Value.namespace == self.namespace)
            if prefix:
                stmt = stmt.filter(Value.key.startswith(prefix, autoescape=True))
            for key in await session.scalars(stmt):
                yield str(key)

    @contextlib.contextmanager
    def _make_sync_session(self) -> Generator[Session, None, None]:
        """Make a sync session (the context manager closes it)."""
        if self.async_mode:
            raise ValueError(
                "Attempting to use a sync method when async mode is turned on. "
                "Please use the corresponding async method instead."
            )
        with cast(Session, self.session_maker()) as session:
            yield cast(Session, session)

    @contextlib.asynccontextmanager
    async def _make_async_session(self) -> AsyncGenerator[AsyncSession, None]:
        """Make an async session (the context manager closes it)."""
        if not self.async_mode:
            raise ValueError(
                "Attempting to use an async method when sync mode is turned on. "
                # Fix: the message previously told the user to use the *async*
                # method, which is what they just called.
                "Please use the corresponding sync method instead."
            )
        async with cast(AsyncSession, self.session_maker()) as session:
            yield cast(AsyncSession, session)
"^1.11.0" +aiosqlite = "^0.19.0" langchain-core = {path = "../core", develop = true} langchain = {path = "../langchain", develop = true} diff --git a/libs/community/tests/unit_tests/docstore/test_sqlstore.py b/libs/community/tests/unit_tests/docstore/test_sqlstore.py new file mode 100644 index 0000000000000..39b9bb9fa4be5 --- /dev/null +++ b/libs/community/tests/unit_tests/docstore/test_sqlstore.py @@ -0,0 +1,89 @@ +from typing import AsyncGenerator, Generator, cast + +import pytest +from langchain.storage._lc_store import create_kv_docstore, create_lc_store +from langchain_core.documents import Document +from langchain_core.stores import BaseStore + +from langchain_community.docstore import SQLStore + + +@pytest.fixture +def sql_store() -> Generator[SQLStore, None, None]: + store = SQLStore(namespace="test", db_url="sqlite://") + store.create_schema() + yield store + + +@pytest.fixture +async def async_sql_store() -> AsyncGenerator[SQLStore, None]: + store = SQLStore(namespace="test", db_url="sqlite+aiosqlite://", async_mode=True) + await store.acreate_schema() + yield store + + +def test_create_lc_store(sql_store: SQLStore) -> None: + """Test that a docstore is created from a base store.""" + docstore: BaseStore[str, Document] = cast( + BaseStore[str, Document], create_lc_store(sql_store) + ) + docstore.mset([("key1", Document(page_content="hello", metadata={"key": "value"}))]) + fetched_doc = docstore.mget(["key1"])[0] + assert fetched_doc is not None + assert fetched_doc.page_content == "hello" + assert fetched_doc.metadata == {"key": "value"} + + +def test_create_kv_store(sql_store: SQLStore) -> None: + """Test that a docstore is created from a base store.""" + docstore = create_kv_docstore(sql_store) + docstore.mset([("key1", Document(page_content="hello", metadata={"key": "value"}))]) + fetched_doc = docstore.mget(["key1"])[0] + assert isinstance(fetched_doc, Document) + assert fetched_doc.page_content == "hello" + assert fetched_doc.metadata == {"key": 
"value"} + + +@pytest.mark.asyncio +async def test_async_create_kv_store(async_sql_store: SQLStore) -> None: + """Test that a docstore is created from a base store.""" + docstore = create_kv_docstore(async_sql_store) + await docstore.amset( + [("key1", Document(page_content="hello", metadata={"key": "value"}))] + ) + fetched_doc = (await docstore.amget(["key1"]))[0] + assert isinstance(fetched_doc, Document) + assert fetched_doc.page_content == "hello" + assert fetched_doc.metadata == {"key": "value"} + + +def test_sample_sql_docstore(sql_store: SQLStore) -> None: + # Set values for keys + sql_store.mset([("key1", b"value1"), ("key2", b"value2")]) + + # Get values for keys + values = sql_store.mget(["key1", "key2"]) # Returns [b"value1", b"value2"] + assert values == [b"value1", b"value2"] + # Delete keys + sql_store.mdelete(["key1"]) + + # Iterate over keys + assert [key for key in sql_store.yield_keys()] == ["key2"] + + +@pytest.mark.asyncio +async def test_async_sample_sql_docstore(async_sql_store: SQLStore) -> None: + # Set values for keys + await async_sql_store.amset([("key1", b"value1"), ("key2", b"value2")]) + # sql_store.mset([("key1", "value1"), ("key2", "value2")]) + + # Get values for keys + values = await async_sql_store.amget( + ["key1", "key2"] + ) # Returns [b"value1", b"value2"] + assert values == [b"value1", b"value2"] + # Delete keys + await async_sql_store.amdelete(["key1"]) + + # Iterate over keys + assert [key async for key in async_sql_store.ayield_keys()] == ["key2"]