Skip to content

community[minor]: Add native async support to SQLChatMessageHistory #22065

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 40 commits into from
Jun 5, 2024
Merged
Show file tree
Hide file tree
Changes from 39 commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
31dc9c0
Real async version of SQLChatMessageHistory
pprados May 23, 2024
1eea86d
Fix spelling
pprados May 23, 2024
a19ca59
Fix spelling
pprados May 23, 2024
6a7a1b0
Merge branch 'master' into pprados/fix_sql_chat_history
pprados May 23, 2024
bab80a6
Fix dependencies
pprados May 23, 2024
8dfe16d
Fix revue
pprados May 23, 2024
35b7068
Fix poetry
pprados May 23, 2024
2e008a1
Merge branch 'master' into pprados/fix_sql_chat_history
pprados May 23, 2024
862ebcb
Fix format
pprados May 23, 2024
a2e5d3f
Merge remote-tracking branch 'origin/pprados/fix_sql_chat_history' in…
pprados May 23, 2024
df1f50a
Fix format
pprados May 23, 2024
9126f8c
Fix format
pprados May 23, 2024
ba28b80
Merge branch 'master' into pprados/fix_sql_chat_history
pprados May 23, 2024
17245b5
Merge branch 'master' into pprados/fix_sql_chat_history
pprados May 23, 2024
601a480
Merge branch 'master' into pprados/fix_sql_chat_history
pprados May 24, 2024
3a51c8e
Add @deprecated
pprados May 24, 2024
bd47d2f
Merge remote-tracking branch 'origin/pprados/fix_sql_chat_history' in…
pprados May 24, 2024
44ddc44
Fix @deprecated
pprados May 24, 2024
769608d
Fix @deprecated
pprados May 24, 2024
f74496d
Fix format
pprados May 24, 2024
14a4bc8
Fix format
pprados May 24, 2024
e858162
Merge branch 'master' into pprados/fix_sql_chat_history
pprados May 24, 2024
731c82f
Only once warning
pprados May 25, 2024
1391f54
Merge remote-tracking branch 'origin/pprados/fix_sql_chat_history' in…
pprados May 25, 2024
75bc80e
Merge branch 'master' into pprados/fix_sql_chat_history
pprados May 25, 2024
53a6abd
Only once warning
pprados May 25, 2024
39a983a
Merge remote-tracking branch 'origin/pprados/fix_sql_chat_history' in…
pprados May 25, 2024
6f1e932
Only once warning
pprados May 25, 2024
8f3c6a1
Warning when create session
pprados May 25, 2024
bb345ba
Warning when create session
pprados May 25, 2024
b6591ae
Fix typing
pprados May 25, 2024
31b6c56
Remove duplicate assertion
pprados May 27, 2024
1df1620
Fix base
pprados May 27, 2024
fafc598
Add get_messages()
pprados May 28, 2024
fc98107
Merge branch 'master' into pprados/fix_sql_chat_history
pprados May 29, 2024
4d153f5
Merge branch 'master' into pprados/fix_sql_chat_history
pprados May 31, 2024
95a097e
Merge branch 'master' into pprados/fix_sql_chat_history
eyurtsev Jun 5, 2024
f991c99
x
eyurtsev Jun 5, 2024
df1d62e
Revert changes to test dependencies. aiosqlite should be optional
eyurtsev Jun 5, 2024
68df505
relock
eyurtsev Jun 5, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
199 changes: 188 additions & 11 deletions libs/community/langchain_community/chat_message_histories/sql.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,22 @@
import asyncio
import contextlib
import json
import logging
from abc import ABC, abstractmethod
from typing import Any, List, Optional
from typing import (
Any,
AsyncGenerator,
Dict,
Generator,
List,
Optional,
Sequence,
Union,
cast,
)

from sqlalchemy import Column, Integer, Text, create_engine
from langchain_core._api import deprecated, warn_deprecated
from sqlalchemy import Column, Integer, Text, delete, select

try:
from sqlalchemy.orm import declarative_base
Expand All @@ -15,7 +28,22 @@
message_to_dict,
messages_from_dict,
)
from sqlalchemy.orm import sessionmaker
from sqlalchemy import create_engine
from sqlalchemy.engine import Engine
from sqlalchemy.ext.asyncio import (
AsyncEngine,
AsyncSession,
async_sessionmaker,
create_async_engine,
)
from sqlalchemy.orm import (
Session as SQLSession,
)
from sqlalchemy.orm import (
declarative_base,
scoped_session,
sessionmaker,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -80,36 +108,98 @@ def get_sql_model_class(self) -> Any:
return self.model_class


DBConnection = Union[AsyncEngine, Engine, str]

_warned_once_already = False


class SQLChatMessageHistory(BaseChatMessageHistory):
"""Chat message history stored in an SQL database."""

@property
@deprecated("0.2.2", removal="0.3.0", alternative="session_maker")
def Session(self) -> Union[scoped_session, async_sessionmaker]:
return self.session_maker

def __init__(
self,
session_id: str,
connection_string: str,
connection_string: Optional[str] = None,
table_name: str = "message_store",
session_id_field_name: str = "session_id",
custom_message_converter: Optional[BaseMessageConverter] = None,
connection: Union[None, DBConnection] = None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(unrelated to this PR)


@pprados this pattern could be useful if you needed a transaction (e.g., as was the case in the indexing API if I recall correctly)

Allows passing a session maker as part of the init instead of initializing and then mutating a class attribute (the latter pattern is more likely to break in threaded scenarios)

engine_args: Optional[Dict[str, Any]] = None,
async_mode: Optional[bool] = None, # Use only if connection is a string
):
self.connection_string = connection_string
self.engine = create_engine(connection_string, echo=False)
assert not (
connection_string and connection
), "connection_string and connection are mutually exclusive"
if connection_string:
global _warned_once_already
if not _warned_once_already:
warn_deprecated(
since="0.2.2",
removal="0.3.0",
name="connection_string",
alternative="Use connection instead",
)
_warned_once_already = True
connection = connection_string
self.connection_string = connection_string
if isinstance(connection, str):
self.async_mode = async_mode
if async_mode:
self.async_engine = create_async_engine(
connection, **(engine_args or {})
)
else:
self.engine = create_engine(url=connection, **(engine_args or {}))
elif isinstance(connection, Engine):
self.async_mode = False
self.engine = connection
elif isinstance(connection, AsyncEngine):
self.async_mode = True
self.async_engine = connection
else:
raise ValueError(
"connection should be a connection string or an instance of "
"sqlalchemy.engine.Engine or sqlalchemy.ext.asyncio.engine.AsyncEngine"
)

# To be consistent with others SQL implementations, rename to session_maker
self.session_maker: Union[scoped_session, async_sessionmaker]
if self.async_mode:
self.session_maker = async_sessionmaker(bind=self.async_engine)
else:
self.session_maker = scoped_session(sessionmaker(bind=self.engine))

self.session_id_field_name = session_id_field_name
self.converter = custom_message_converter or DefaultMessageConverter(table_name)
self.sql_model_class = self.converter.get_sql_model_class()
if not hasattr(self.sql_model_class, session_id_field_name):
raise ValueError("SQL model class must have session_id column")
self._create_table_if_not_exists()
self._table_created = False
if not self.async_mode:
self._create_table_if_not_exists()

self.session_id = session_id
self.Session = sessionmaker(self.engine)

def _create_table_if_not_exists(self) -> None:
self.sql_model_class.metadata.create_all(self.engine)
self._table_created = True

async def _acreate_table_if_not_exists(self) -> None:
if not self._table_created:
assert self.async_mode, "This method must be called with async_mode"
async with self.async_engine.begin() as conn:
await conn.run_sync(self.sql_model_class.metadata.create_all)
self._table_created = True

@property
def messages(self) -> List[BaseMessage]: # type: ignore
"""Retrieve all messages from db"""
with self.Session() as session:
with self._make_sync_session() as session:
result = (
session.query(self.sql_model_class)
.where(
Expand All @@ -123,18 +213,105 @@ def messages(self) -> List[BaseMessage]: # type: ignore
messages.append(self.converter.from_sql_model(record))
return messages

def get_messages(self) -> List[BaseMessage]:
return self.messages

async def aget_messages(self) -> List[BaseMessage]:
"""Retrieve all messages from db"""
await self._acreate_table_if_not_exists()
async with self._make_async_session() as session:
stmt = (
select(self.sql_model_class)
.where(
getattr(self.sql_model_class, self.session_id_field_name)
== self.session_id
)
.order_by(self.sql_model_class.id.asc())
)
result = await session.execute(stmt)
messages = []
for record in result.scalars():
messages.append(self.converter.from_sql_model(record))
return messages

def add_message(self, message: BaseMessage) -> None:
"""Append the message to the record in db"""
with self.Session() as session:
with self._make_sync_session() as session:
session.add(self.converter.to_sql_model(message, self.session_id))
session.commit()

async def aadd_message(self, message: BaseMessage) -> None:
"""Add a Message object to the store.

Args:
message: A BaseMessage object to store.
"""
await self._acreate_table_if_not_exists()
async with self._make_async_session() as session:
session.add(self.converter.to_sql_model(message, self.session_id))
await session.commit()

def add_messages(self, messages: Sequence[BaseMessage]) -> None:
# The method RunnableWithMessageHistory._exit_history() call
# add_message method by mistake and not aadd_message.
# See https://github.com/langchain-ai/langchain/issues/22021
if self.async_mode:
loop = asyncio.get_event_loop()
loop.run_until_complete(self.aadd_messages(messages))
else:
with self._make_sync_session() as session:
for message in messages:
session.add(self.converter.to_sql_model(message, self.session_id))
session.commit()

async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None:
# Add all messages in one transaction
await self._acreate_table_if_not_exists()
async with self.session_maker() as session:
for message in messages:
session.add(self.converter.to_sql_model(message, self.session_id))
await session.commit()

def clear(self) -> None:
"""Clear session memory from db"""

with self.Session() as session:
with self._make_sync_session() as session:
session.query(self.sql_model_class).filter(
getattr(self.sql_model_class, self.session_id_field_name)
== self.session_id
).delete()
session.commit()

async def aclear(self) -> None:
"""Clear session memory from db"""

await self._acreate_table_if_not_exists()
async with self._make_async_session() as session:
stmt = delete(self.sql_model_class).filter(
getattr(self.sql_model_class, self.session_id_field_name)
== self.session_id
)
await session.execute(stmt)
await session.commit()

@contextlib.contextmanager
def _make_sync_session(self) -> Generator[SQLSession, None, None]:
"""Make an async session."""
if self.async_mode:
raise ValueError(
"Attempting to use a sync method in when async mode is turned on. "
"Please use the corresponding async method instead."
)
with self.session_maker() as session:
yield cast(SQLSession, session)

@contextlib.asynccontextmanager
async def _make_async_session(self) -> AsyncGenerator[AsyncSession, None]:
"""Make an async session."""
if not self.async_mode:
raise ValueError(
"Attempting to use an async method in when sync mode is turned on. "
"Please use the corresponding async method instead."
)
async with self.session_maker() as session:
yield cast(AsyncSession, session)
21 changes: 8 additions & 13 deletions libs/community/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions libs/community/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,7 @@ extended_testing = [
"pyjwt",
"oracledb",
"simsimd",
"aiosqlite"
]

[tool.ruff]
Expand Down
Loading
Loading