From 47905af8fce2d6930f1c61d649705491ade46f8a Mon Sep 17 00:00:00 2001
From: Philippe Prados
Date: Fri, 19 Jan 2024 16:00:02 +0100
Subject: [PATCH 01/13] Add SQL Docstore

---
 .../storage/sql_docstore.py | 236 ++++++++++++++++++
 1 file changed, 236 insertions(+)
 create mode 100644 libs/community/langchain_community/storage/sql_docstore.py

diff --git a/libs/community/langchain_community/storage/sql_docstore.py b/libs/community/langchain_community/storage/sql_docstore.py
new file mode 100644
index 0000000000000..5a1adfffa0fd3
--- /dev/null
+++ b/libs/community/langchain_community/storage/sql_docstore.py
@@ -0,0 +1,236 @@
+import contextlib
+from pathlib import Path
+from typing import (
+    Any,
+    Dict,
+    Generator,
+    Iterator,
+    List,
+    Optional,
+    Sequence,
+    Tuple,
+    Union,
+)
+
+from langchain_core.stores import BaseStore
+from sqlalchemy import (
+    Column,
+    Engine,
+    PickleType,
+    and_,
+    create_engine,
+)
+from sqlalchemy.ext.asyncio import (
+    AsyncEngine,
+    AsyncSession,
+    async_sessionmaker,
+    create_async_engine,
+)
+from sqlalchemy.orm import (
+    Mapped,
+    Session,
+    declarative_base,
+    mapped_column,
+    scoped_session,
+    sessionmaker,
+)
+
+Base = declarative_base()
+
+
+def items_equal(x: Any, y: Any) -> bool:
+    """Compare two unpickled values for equality (used as PickleType comparator)."""
+    return x == y
+
+
+class Value(Base):  # type: ignore[valid-type,misc]
+    """Table used to save values."""
+
+    # ATTENTION:
+    # Prior to modifying this table, please determine whether
+    # we should create migrations for this table to make sure
+    # users do not experience data loss.
+    __tablename__ = "docstore"
+
+    namespace: Mapped[str] = mapped_column(primary_key=True, index=True, nullable=False)
+    key: Mapped[str] = mapped_column(primary_key=True, index=True, nullable=False)
+    value: Any = Column("value", PickleType(comparator=items_equal))
+
+
+# This is a fix of the original SQLStore.
+# It can be removed once the upstream PR is merged.
+class SQLStore(BaseStore[str, bytes]):
+    """BaseStore interface that works on an SQL database.
+
+    Examples:
+        Create a SQLStore instance and perform operations on it:
+
+        .. code-block:: python
+
+            from langchain_community.storage import SQLStore
+
+            # Instantiate the SQLStore with an in-memory SQLite database
+            sql_store = SQLStore(namespace="test", db_url="sqlite:///:memory:")
+            sql_store.create_schema()
+
+            # Set values for keys
+            sql_store.mset([("key1", b"value1"), ("key2", b"value2")])
+
+            # Get values for keys
+            values = sql_store.mget(["key1", "key2"])  # Returns [b"value1", b"value2"]
+
+            # Delete keys
+            sql_store.mdelete(["key1"])
+
+            # Iterate over keys
+            for key in sql_store.yield_keys():
+                print(key)
+
+    """
+
+    def __init__(
+        self,
+        *,
+        namespace: str,
+        db_url: Optional[Union[str, Path]] = None,
+        engine: Optional[Union[Engine, AsyncEngine]] = None,
+        engine_kwargs: Optional[Dict[str, Any]] = None,
+        async_mode: bool = False,
+    ):
+        if db_url is None and engine is None:
+            raise ValueError("Must specify either db_url or engine")
+
+        if db_url is not None and engine is not None:
+            raise ValueError("Must specify either db_url or engine, not both")
+
+        _engine: Union[Engine, AsyncEngine]
+        if db_url:
+            if async_mode:
+                _engine = create_async_engine(url=str(db_url), **(engine_kwargs or {}))
+            else:
+                _engine = create_engine(url=str(db_url), **(engine_kwargs or {}))
+        elif engine:
+            _engine = engine
+        else:
+            raise AssertionError("Something went wrong with configuration of engine.")
+
+        _session_factory: Union[sessionmaker[Session], async_sessionmaker[AsyncSession]]
+        if isinstance(_engine, AsyncEngine):
+            _session_factory = async_sessionmaker(bind=_engine)
+        else:
+            _session_factory = sessionmaker(bind=_engine)
+
+        self.engine = _engine
+        self.dialect = _engine.dialect.name
+        self.session_factory = _session_factory
+        self.namespace = namespace
+
+    def create_schema(self) -> None:
+        Base.metadata.create_all(self.engine)
+
+    def drop(self) -> None:
+        Base.metadata.drop_all(bind=self.engine.connect())
+
+    # async def amget(self, keys: Sequence[K]) -> List[Optional[V]]:
+    #     result = {}
+    #     async with self._amake_session() as session:
+    #         async with session.begin():
+    #             for v in session.query(Value).filter(
+    #                 and_(
+    #                     Value.key.in_(keys),
+    #                     Value.namespace == self.namespace,
+    #                 )
+    #             ):
+    #                 result[v.key] = v.value
+    #     return [result.get(key) for key in keys]
+
+    def mget(self, keys: Sequence[str]) -> List[Optional[bytes]]:
+        result = {}
+
+        with self._make_session() as session:
+            for v in session.query(Value).filter(  # type: ignore
+                and_(
+                    Value.key.in_(keys),
+                    Value.namespace == self.namespace,
+                )
+            ):
+                result[v.key] = v.value
+        return [result.get(key) for key in keys]
+
+    # async def amset(self, key_value_pairs: Sequence[Tuple[K, V]]) -> None:
+    #     async with self._amake_session() as session:
+    #         async with session.begin():
+    #             # await self._amdelete([key for key, _ in key_value_pairs], session)
+    #             session.add_all([Value(namespace=self.namespace,
+    #                                    key=k,
+    #                                    value=v) for k, v in key_value_pairs])
+    #         session.commit()
+
+    def mset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None:
+        with self._make_session() as session:
+            self._mdelete([key for key, _ in key_value_pairs], session)
+            session.add_all(
+                [
+                    Value(namespace=self.namespace, key=k, value=v)
+                    for k, v in key_value_pairs
+                ]
+            )
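+    # Clearing existing rows for the incoming keys before re-inserting them is
+    # what makes mset() behave as an upsert: writing an existing key would
+    # otherwise violate the composite (namespace, key) primary key.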
+    def _mdelete(self, keys: Sequence[str], session: Session) -> None:
+        with session.begin():
+            session.query(Value).filter(  # type: ignore
+                and_(
+                    Value.key.in_(keys),
+                    Value.namespace == self.namespace,
+                )
+            ).delete()
+
+    # async def _amdelete(self, keys: Sequence[str], session: AsyncSession) -> None:
+    #     await session.query(Value).filter(
+    #         and_(
+    #             Value.key.in_(keys),
+    #             Value.namespace == self.namespace,
+    #         )
+    #     ).delete()
+
+    def mdelete(self, keys: Sequence[str]) -> None:
+        with self._make_session() as session:
+            self._mdelete(keys, session)
+            session.commit()
+
+    # async def amdelete(self, keys: Sequence[str]) -> None:
+    #     async with self._amake_session() as session:
+    #         await self._amdelete(keys, session)
+    #         await session.commit()
+
+    def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
+        with self._make_session() as session:
+            query = session.query(Value).filter(  # type: ignore
+                Value.namespace == self.namespace
+            )
+            # Honor the prefix argument instead of silently ignoring it.
+            if prefix:
+                query = query.filter(Value.key.like(f"{prefix}%"))
+            for v in query:
+                yield str(v.key)
+
+    @contextlib.contextmanager
+    def _make_session(self) -> Generator[Session, None, None]:
+        """Create a session and close it after use."""
+
+        if isinstance(self.session_factory, async_sessionmaker):
+            raise AssertionError("This method is not supported for async engines.")
+
+        session = scoped_session(self.session_factory)()
+        try:
+            yield session
+        finally:
+            session.commit()
+
+    # @contextlib.asynccontextmanager
+    # async def _amake_session(self) -> AsyncGenerator[AsyncSession, None]:
+    #     """Create a session and close it after use."""
+    #
+    #     if not isinstance(self.session_factory, async_sessionmaker):
+    #         raise AssertionError("This method is not supported for sync engines.")
+    #
+    #     async with self.session_factory() as session:
+    #         yield session

From e80063d031e267b5ddae20b91a668c7245f4db58 Mon Sep 17 00:00:00 2001
From: Bagatur <22008038+baskaryan@users.noreply.github.com>
Date: Mon, 22 Jan 2024 08:19:08 -0800
Subject: [PATCH 02/13] core[patch], community[patch], langchain[patch], docs: Update SQL chains/agents/docs (#16168)

Revamp SQL use cases docs. In the process update SQL chains and agents.

---
 docs/docs/use_cases/sql/agents.ipynb         | 815 +++++++++++++++++++
 docs/docs/use_cases/sql/index.ipynb          |  68 ++
 docs/docs/use_cases/sql/large_db.ipynb       | 627 ++++++++++++++
 docs/docs/use_cases/sql/prompting.ipynb      | 789 ++++++++++++++++++
 docs/docs/use_cases/sql/query_checking.ipynb | 389 +++++++++
 docs/docs/use_cases/sql/quickstart.ipynb     | 603 ++++++++++++++
 6 files changed, 3291 insertions(+)
 create mode 100644 docs/docs/use_cases/sql/agents.ipynb
 create mode 100644 docs/docs/use_cases/sql/index.ipynb
 create mode 100644 docs/docs/use_cases/sql/large_db.ipynb
 create mode 100644 docs/docs/use_cases/sql/prompting.ipynb
 create mode 100644 docs/docs/use_cases/sql/query_checking.ipynb
 create mode 100644 docs/docs/use_cases/sql/quickstart.ipynb

diff --git a/docs/docs/use_cases/sql/agents.ipynb b/docs/docs/use_cases/sql/agents.ipynb
new file mode 100644
index 0000000000000..aa6db0dd3f920
--- /dev/null
+++ b/docs/docs/use_cases/sql/agents.ipynb
@@ -0,0 +1,815 @@
+{
+ "cells": [
+  {
+   "cell_type": "raw",
+   "metadata": {},
+   "source": [
+    "---\n",
+    "sidebar_position: 1\n",
+    "---"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Agents\n",
+    "\n",
+    "LangChain has a SQL Agent which provides a more flexible way of interacting with SQL databases than a chain.\n",
+    "\n",
+    "The main advantages of using the SQL Agent are:\n",
+    "\n",
+    "- It can answer questions based on the database's schema as well as on the database's content (like describing a specific table).\n",
+    "- It can recover from errors by running a generated query, catching the traceback, and regenerating the query correctly.\n",
+    "- It can query the database as many times as needed to answer the user question.\n",
+    "\n",
+    "To initialize the agent we'll use the [create_sql_agent](https://api.python.langchain.com/en/latest/agent_toolkits/langchain_community.agent_toolkits.sql.base.create_sql_agent.html) constructor. This agent uses the `SQLDatabaseToolkit`, which contains tools to: \n",
+    "\n",
+    "* Create and execute queries\n",
+    "* Check query syntax\n",
+    "* Retrieve table descriptions\n",
+    "* ... and more\n",
+    "\n",
+    "## Setup\n",
+    "\n",
+    "First, get required packages and set environment variables:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%pip install --upgrade --quiet langchain langchain-community langchain-openai"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We default to OpenAI models in this guide, but you can swap them out for the model provider of your choice."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import getpass\n",
+    "import os\n",
+    "\n",
+    "os.environ[\"OPENAI_API_KEY\"] = getpass.getpass()\n",
+    "\n",
+    "# Uncomment the below to use LangSmith. Not required.\n",
+    "# os.environ[\"LANGCHAIN_API_KEY\"] = getpass.getpass()\n",
+    "# os.environ[\"LANGCHAIN_TRACING_V2\"] = \"true\""
+   ]
+  },
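+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "For example, a minimal (untested) sketch of swapping in Anthropic instead, assuming the `langchain-community` package is installed and `ANTHROPIC_API_KEY` is set:\n",
+    "\n",
+    "```python\n",
+    "from langchain_community.chat_models import ChatAnthropic\n",
+    "\n",
+    "# Any LangChain chat model can stand in for ChatOpenAI below.\n",
+    "llm = ChatAnthropic(model=\"claude-2\", temperature=0)\n",
+    "```"
+   ]
+  },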
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The below example will use a SQLite connection with the Chinook database. Follow [these installation steps](https://database.guide/2-sample-databases-sqlite/) to create `Chinook.db` in the same directory as this notebook:\n",
+    "\n",
+    "* Save [this file](https://raw.githubusercontent.com/lerocha/chinook-database/master/ChinookDatabase/DataSources/Chinook_Sqlite.sql) as `Chinook_Sqlite.sql`\n",
+    "* Run `sqlite3 Chinook.db`\n",
+    "* Run `.read Chinook_Sqlite.sql`\n",
+    "* Test `SELECT * FROM Artist LIMIT 10;`\n",
+    "\n",
+    "Now, `Chinook.db` is in our directory and we can interface with it using the SQLAlchemy-driven `SQLDatabase` class:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "sqlite\n",
+      "['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "\"[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]\""
+      ]
+     },
+     "execution_count": 1,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "from langchain_community.utilities import SQLDatabase\n",
+    "\n",
+    "db = SQLDatabase.from_uri(\"sqlite:///Chinook.db\")\n",
+    "print(db.dialect)\n",
+    "print(db.get_usable_table_names())\n",
+    "db.run(\"SELECT * FROM Artist LIMIT 10;\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Agent\n",
+    "\n",
+    "We'll use an OpenAI chat model and an `\"openai-tools\"` agent, which will use OpenAI's function-calling API to drive the agent's tool selection and invocations."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 21,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain_community.agent_toolkits import create_sql_agent\n",
+    "from langchain_openai import ChatOpenAI\n",
+    "\n",
+    "llm = ChatOpenAI(model=\"gpt-3.5-turbo\", temperature=0)\n",
+    "agent_executor = create_sql_agent(llm, db=db, agent_type=\"openai-tools\", verbose=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "\n",
+      "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
+      "\u001b[32;1m\u001b[1;3m\n",
+      "Invoking: `sql_db_list_tables` with `{}`\n",
+      "\n",
+      "\n",
+      "\u001b[0m\u001b[38;5;200m\u001b[1;3mAlbum, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track\u001b[0m\u001b[32;1m\u001b[1;3m\n",
+      "Invoking: `sql_db_schema` with `Invoice,Customer`\n",
+      "\n",
+      "\n",
+      "\u001b[0m\u001b[33;1m\u001b[1;3m\n",
+      "CREATE TABLE \"Customer\" (\n",
+      "\t\"CustomerId\" INTEGER NOT NULL, \n",
+      "\t\"FirstName\" NVARCHAR(40) NOT NULL, \n",
+      "\t\"LastName\" NVARCHAR(20) NOT NULL, \n",
+      "\t\"Company\" NVARCHAR(80), \n",
+      "\t\"Address\" NVARCHAR(70), \n",
+      "\t\"City\" NVARCHAR(40), \n",
+      "\t\"State\" NVARCHAR(40), \n",
+      "\t\"Country\" NVARCHAR(40), \n",
+      "\t\"PostalCode\" NVARCHAR(10), \n",
+      "\t\"Phone\" NVARCHAR(24), \n",
+      "\t\"Fax\" NVARCHAR(24), \n",
+      "\t\"Email\" NVARCHAR(60) NOT NULL, \n",
+      "\t\"SupportRepId\" INTEGER, \n",
+      "\tPRIMARY KEY (\"CustomerId\"), \n",
+      "\tFOREIGN KEY(\"SupportRepId\") REFERENCES \"Employee\" (\"EmployeeId\")\n",
+      ")\n",
+      "\n",
+      "/*\n",
+      "3 rows from Customer table:\n",
"CustomerId\tFirstName\tLastName\tCompany\tAddress\tCity\tState\tCountry\tPostalCode\tPhone\tFax\tEmail\tSupportRepId\n", + "1\tLuís\tGonçalves\tEmbraer - Empresa Brasileira de Aeronáutica S.A.\tAv. Brigadeiro Faria Lima, 2170\tSão José dos Campos\tSP\tBrazil\t12227-000\t+55 (12) 3923-5555\t+55 (12) 3923-5566\tluisg@embraer.com.br\t3\n", + "2\tLeonie\tKöhler\tNone\tTheodor-Heuss-Straße 34\tStuttgart\tNone\tGermany\t70174\t+49 0711 2842222\tNone\tleonekohler@surfeu.de\t5\n", + "3\tFrançois\tTremblay\tNone\t1498 rue Bélanger\tMontréal\tQC\tCanada\tH2G 1A7\t+1 (514) 721-4711\tNone\tftremblay@gmail.com\t3\n", + "*/\n", + "\n", + "\n", + "CREATE TABLE \"Invoice\" (\n", + "\t\"InvoiceId\" INTEGER NOT NULL, \n", + "\t\"CustomerId\" INTEGER NOT NULL, \n", + "\t\"InvoiceDate\" DATETIME NOT NULL, \n", + "\t\"BillingAddress\" NVARCHAR(70), \n", + "\t\"BillingCity\" NVARCHAR(40), \n", + "\t\"BillingState\" NVARCHAR(40), \n", + "\t\"BillingCountry\" NVARCHAR(40), \n", + "\t\"BillingPostalCode\" NVARCHAR(10), \n", + "\t\"Total\" NUMERIC(10, 2) NOT NULL, \n", + "\tPRIMARY KEY (\"InvoiceId\"), \n", + "\tFOREIGN KEY(\"CustomerId\") REFERENCES \"Customer\" (\"CustomerId\")\n", + ")\n", + "\n", + "/*\n", + "3 rows from Invoice table:\n", + "InvoiceId\tCustomerId\tInvoiceDate\tBillingAddress\tBillingCity\tBillingState\tBillingCountry\tBillingPostalCode\tTotal\n", + "1\t2\t2009-01-01 00:00:00\tTheodor-Heuss-Straße 34\tStuttgart\tNone\tGermany\t70174\t1.98\n", + "2\t4\t2009-01-02 00:00:00\tUllevålsveien 14\tOslo\tNone\tNorway\t0171\t3.96\n", + "3\t8\t2009-01-03 00:00:00\tGrétrystraat 63\tBrussels\tNone\tBelgium\t1000\t5.94\n", + "*/\u001b[0m\u001b[32;1m\u001b[1;3m\n", + "Invoking: `sql_db_query` with `SELECT c.Country, SUM(i.Total) AS TotalSales FROM Invoice i JOIN Customer c ON i.CustomerId = c.CustomerId GROUP BY c.Country ORDER BY TotalSales DESC`\n", + "responded: To list the total sales per country, I can query the \"Invoice\" and \"Customer\" tables. I will join these tables on the \"CustomerId\" column and group the results by the \"BillingCountry\" column. Then, I will calculate the sum of the \"Total\" column to get the total sales per country. Finally, I will order the results in descending order of the total sales.\n", + "\n", + "Here is the SQL query:\n", + "\n", + "```sql\n", + "SELECT c.Country, SUM(i.Total) AS TotalSales\n", + "FROM Invoice i\n", + "JOIN Customer c ON i.CustomerId = c.CustomerId\n", + "GROUP BY c.Country\n", + "ORDER BY TotalSales DESC\n", + "```\n", + "\n", + "Now, I will execute this query to get the results.\n", + "\n", + "\u001b[0m\u001b[36;1m\u001b[1;3m[('USA', 523.0600000000003), ('Canada', 303.9599999999999), ('France', 195.09999999999994), ('Brazil', 190.09999999999997), ('Germany', 156.48), ('United Kingdom', 112.85999999999999), ('Czech Republic', 90.24000000000001), ('Portugal', 77.23999999999998), ('India', 75.25999999999999), ('Chile', 46.62), ('Ireland', 45.62), ('Hungary', 45.62), ('Austria', 42.62), ('Finland', 41.620000000000005), ('Netherlands', 40.62), ('Norway', 39.62), ('Sweden', 38.620000000000005), ('Poland', 37.620000000000005), ('Italy', 37.620000000000005), ('Denmark', 37.620000000000005), ('Australia', 37.620000000000005), ('Argentina', 37.620000000000005), ('Spain', 37.62), ('Belgium', 37.62)]\u001b[0m\u001b[32;1m\u001b[1;3mThe total sales per country are as follows:\n", + "\n", + "1. USA: $523.06\n", + "2. Canada: $303.96\n", + "3. France: $195.10\n", + "4. Brazil: $190.10\n", + "5. Germany: $156.48\n", + "6. United Kingdom: $112.86\n", + "7. 
Czech Republic: $90.24\n", + "8. Portugal: $77.24\n", + "9. India: $75.26\n", + "10. Chile: $46.62\n", + "\n", + "The country whose customers spent the most is the USA, with a total sales of $523.06.\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "{'input': \"List the total sales per country. Which country's customers spent the most?\",\n", + " 'output': 'The total sales per country are as follows:\\n\\n1. USA: $523.06\\n2. Canada: $303.96\\n3. France: $195.10\\n4. Brazil: $190.10\\n5. Germany: $156.48\\n6. United Kingdom: $112.86\\n7. Czech Republic: $90.24\\n8. Portugal: $77.24\\n9. India: $75.26\\n10. Chile: $46.62\\n\\nThe country whose customers spent the most is the USA, with a total sales of $523.06.'}" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "agent_executor.invoke(\n", + " \"List the total sales per country. Which country's customers spent the most?\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", + "\u001b[32;1m\u001b[1;3m\n", + "Invoking: `sql_db_list_tables` with `{}`\n", + "\n", + "\n", + "\u001b[0m\u001b[38;5;200m\u001b[1;3mAlbum, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track\u001b[0m\u001b[32;1m\u001b[1;3m\n", + "Invoking: `sql_db_schema` with `PlaylistTrack`\n", + "\n", + "\n", + "\u001b[0m\u001b[33;1m\u001b[1;3m\n", + "CREATE TABLE \"PlaylistTrack\" (\n", + "\t\"PlaylistId\" INTEGER NOT NULL, \n", + "\t\"TrackId\" INTEGER NOT NULL, \n", + "\tPRIMARY KEY (\"PlaylistId\", \"TrackId\"), \n", + "\tFOREIGN KEY(\"TrackId\") REFERENCES \"Track\" (\"TrackId\"), \n", + "\tFOREIGN KEY(\"PlaylistId\") REFERENCES \"Playlist\" (\"PlaylistId\")\n", + ")\n", + "\n", + "/*\n", + "3 rows from PlaylistTrack table:\n", + "PlaylistId\tTrackId\n", + "1\t3402\n", + "1\t3389\n", + "1\t3390\n", + "*/\u001b[0m\u001b[32;1m\u001b[1;3mThe `PlaylistTrack` table has two columns: `PlaylistId` and `TrackId`. It is a junction table that represents the many-to-many relationship between playlists and tracks. \n", + "\n", + "Here is the schema of the `PlaylistTrack` table:\n", + "\n", + "```\n", + "CREATE TABLE \"PlaylistTrack\" (\n", + "\t\"PlaylistId\" INTEGER NOT NULL, \n", + "\t\"TrackId\" INTEGER NOT NULL, \n", + "\tPRIMARY KEY (\"PlaylistId\", \"TrackId\"), \n", + "\tFOREIGN KEY(\"TrackId\") REFERENCES \"Track\" (\"TrackId\"), \n", + "\tFOREIGN KEY(\"PlaylistId\") REFERENCES \"Playlist\" (\"PlaylistId\")\n", + ")\n", + "```\n", + "\n", + "The `PlaylistId` column is a foreign key referencing the `PlaylistId` column in the `Playlist` table. The `TrackId` column is a foreign key referencing the `TrackId` column in the `Track` table.\n", + "\n", + "Here are three sample rows from the `PlaylistTrack` table:\n", + "\n", + "```\n", + "PlaylistId TrackId\n", + "1 3402\n", + "1 3389\n", + "1 3390\n", + "```\n", + "\n", + "Please let me know if there is anything else I can help with.\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "{'input': 'Describe the playlisttrack table',\n", + " 'output': 'The `PlaylistTrack` table has two columns: `PlaylistId` and `TrackId`. It is a junction table that represents the many-to-many relationship between playlists and tracks. 
\\n\\nHere is the schema of the `PlaylistTrack` table:\\n\\n```\\nCREATE TABLE \"PlaylistTrack\" (\\n\\t\"PlaylistId\" INTEGER NOT NULL, \\n\\t\"TrackId\" INTEGER NOT NULL, \\n\\tPRIMARY KEY (\"PlaylistId\", \"TrackId\"), \\n\\tFOREIGN KEY(\"TrackId\") REFERENCES \"Track\" (\"TrackId\"), \\n\\tFOREIGN KEY(\"PlaylistId\") REFERENCES \"Playlist\" (\"PlaylistId\")\\n)\\n```\\n\\nThe `PlaylistId` column is a foreign key referencing the `PlaylistId` column in the `Playlist` table. The `TrackId` column is a foreign key referencing the `TrackId` column in the `Track` table.\\n\\nHere are three sample rows from the `PlaylistTrack` table:\\n\\n```\\nPlaylistId TrackId\\n1 3402\\n1 3389\\n1 3390\\n```\\n\\nPlease let me know if there is anything else I can help with.'}" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "agent_executor.invoke(\"Describe the playlisttrack table\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Using a dynamic few-shot prompt\n", + "\n", + "To optimize agent performance, we can provide a custom prompt with domain-specific knowledge. In this case we'll create a few shot prompt with an example selector, that will dynamically build the few shot prompt based on the user input.\n", + "\n", + "First we need some user input <> SQL query examples:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "examples = [\n", + " {\"input\": \"List all artists.\", \"query\": \"SELECT * FROM Artist;\"},\n", + " {\n", + " \"input\": \"Find all albums for the artist 'AC/DC'.\",\n", + " \"query\": \"SELECT * FROM Album WHERE ArtistId = (SELECT ArtistId FROM Artist WHERE Name = 'AC/DC');\",\n", + " },\n", + " {\n", + " \"input\": \"List all tracks in the 'Rock' genre.\",\n", + " \"query\": \"SELECT * FROM Track WHERE GenreId = (SELECT GenreId FROM Genre WHERE Name = 'Rock');\",\n", + " },\n", + " {\n", + " \"input\": \"Find the total duration of all tracks.\",\n", + " \"query\": \"SELECT SUM(Milliseconds) FROM Track;\",\n", + " },\n", + " {\n", + " \"input\": \"List all customers from Canada.\",\n", + " \"query\": \"SELECT * FROM Customer WHERE Country = 'Canada';\",\n", + " },\n", + " {\n", + " \"input\": \"How many tracks are there in the album with ID 5?\",\n", + " \"query\": \"SELECT COUNT(*) FROM Track WHERE AlbumId = 5;\",\n", + " },\n", + " {\n", + " \"input\": \"Find the total number of invoices.\",\n", + " \"query\": \"SELECT COUNT(*) FROM Invoice;\",\n", + " },\n", + " {\n", + " \"input\": \"List all tracks that are longer than 5 minutes.\",\n", + " \"query\": \"SELECT * FROM Track WHERE Milliseconds > 300000;\",\n", + " },\n", + " {\n", + " \"input\": \"Who are the top 5 customers by total purchase?\",\n", + " \"query\": \"SELECT CustomerId, SUM(Total) AS TotalPurchase FROM Invoice GROUP BY CustomerId ORDER BY TotalPurchase DESC LIMIT 5;\",\n", + " },\n", + " {\n", + " \"input\": \"Which albums are from the year 2000?\",\n", + " \"query\": \"SELECT * FROM Album WHERE strftime('%Y', ReleaseDate) = '2000';\",\n", + " },\n", + " {\n", + " \"input\": \"How many employees are there\",\n", + " \"query\": 'SELECT COUNT(*) FROM \"Employee\"',\n", + " },\n", + "]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we can create an example selector. This will take the actual user input and select some number of examples to add to our few-shot prompt. 
We'll use a SemanticSimilarityExampleSelector, which will perform a semantic search using the embeddings and vector store we configure to find the examples most similar to our input:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from langchain_community.vectorstores import FAISS\n", + "from langchain_core.example_selectors import SemanticSimilarityExampleSelector\n", + "from langchain_openai import OpenAIEmbeddings\n", + "\n", + "example_selector = SemanticSimilarityExampleSelector.from_examples(\n", + " examples,\n", + " OpenAIEmbeddings(),\n", + " FAISS,\n", + " k=5,\n", + " input_keys=[\"input\"],\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we can create our FewShotPromptTemplate, which takes our example selector, an example prompt for formatting each example, and a string prefix and suffix to put before and after our formatted examples:" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "from langchain_core.prompts import (\n", + " ChatPromptTemplate,\n", + " FewShotPromptTemplate,\n", + " MessagesPlaceholder,\n", + " PromptTemplate,\n", + " SystemMessagePromptTemplate,\n", + ")\n", + "\n", + "system_prefix = \"\"\"You are an agent designed to interact with a SQL database.\n", + "Given an input question, create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.\n", + "Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.\n", + "You can order the results by a relevant column to return the most interesting examples in the database.\n", + "Never query for all the columns from a specific table, only ask for the relevant columns given the question.\n", + "You have access to tools for interacting with the database.\n", + "Only use the given tools. Only use the information returned by the tools to construct your final answer.\n", + "You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.\n", + "\n", + "DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.\n", + "\n", + "If the question does not seem related to the database, just return \"I don't know\" as the answer.\n", + "\n", + "Here are some examples of user inputs and their corresponding SQL queries:\"\"\"\n", + "\n", + "few_shot_prompt = FewShotPromptTemplate(\n", + " example_selector=example_selector,\n", + " example_prompt=PromptTemplate.from_template(\n", + " \"User input: {input}\\nSQL query: {query}\"\n", + " ),\n", + " input_variables=[\"input\", \"dialect\", \"top_k\"],\n", + " prefix=system_prefix,\n", + " suffix=\"\",\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Since our underlying agent is an [OpenAI tools agent](/docs/modules/agents/agent_types/openai_tools), which uses OpenAI function calling, our full prompt should be a chat prompt with a human message template and an agent_scratchpad `MessagesPlaceholder`. 
The few-shot prompt will be used for our system message:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "full_prompt = ChatPromptTemplate.from_messages(\n", + " [\n", + " SystemMessagePromptTemplate(prompt=few_shot_prompt),\n", + " (\"human\", \"{input}\"),\n", + " MessagesPlaceholder(\"agent_scratchpad\"),\n", + " ]\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "System: You are an agent designed to interact with a SQL database.\n", + "Given an input question, create a syntactically correct SQLite query to run, then look at the results of the query and return the answer.\n", + "Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 5 results.\n", + "You can order the results by a relevant column to return the most interesting examples in the database.\n", + "Never query for all the columns from a specific table, only ask for the relevant columns given the question.\n", + "You have access to tools for interacting with the database.\n", + "Only use the given tools. Only use the information returned by the tools to construct your final answer.\n", + "You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.\n", + "\n", + "DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.\n", + "\n", + "If the question does not seem related to the database, just return \"I don't know\" as the answer.\n", + "\n", + "Here are some examples of user inputs and their corresponding SQL queries:\n", + "\n", + "User input: List all artists.\n", + "SQL query: SELECT * FROM Artist;\n", + "\n", + "User input: How many employees are there\n", + "SQL query: SELECT COUNT(*) FROM \"Employee\"\n", + "\n", + "User input: How many tracks are there in the album with ID 5?\n", + "SQL query: SELECT COUNT(*) FROM Track WHERE AlbumId = 5;\n", + "\n", + "User input: List all tracks in the 'Rock' genre.\n", + "SQL query: SELECT * FROM Track WHERE GenreId = (SELECT GenreId FROM Genre WHERE Name = 'Rock');\n", + "\n", + "User input: Which albums are from the year 2000?\n", + "SQL query: SELECT * FROM Album WHERE strftime('%Y', ReleaseDate) = '2000';\n", + "Human: How many arists are there\n" + ] + } + ], + "source": [ + "# Example formatted prompt\n", + "prompt_val = full_prompt.invoke(\n", + " {\n", + " \"input\": \"How many arists are there\",\n", + " \"top_k\": 5,\n", + " \"dialect\": \"SQLite\",\n", + " \"agent_scratchpad\": [],\n", + " }\n", + ")\n", + "print(prompt_val.to_string())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "And now we can create our agent with our custom prompt:" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "agent = create_sql_agent(\n", + " llm=llm,\n", + " db=db,\n", + " prompt=full_prompt,\n", + " verbose=True,\n", + " agent_type=\"openai-tools\",\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's try it out:" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", + "\u001b[32;1m\u001b[1;3m\n", + "Invoking: `sql_db_query` with `{'query': 'SELECT 
COUNT(*) FROM Artist'}`\n", + "\n", + "\n", + "\u001b[0m\u001b[36;1m\u001b[1;3m[(275,)]\u001b[0m\u001b[32;1m\u001b[1;3mThere are 275 artists in the database.\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "{'input': 'How many artists are there?',\n", + " 'output': 'There are 275 artists in the database.'}" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "agent.invoke({\"input\": \"How many artists are there?\"})" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Dealing with high-cardinality columns\n", + "\n", + "In order to filter columns that contain proper nouns such as addresses, song names or artists, we first need to double-check the spelling in order to filter the data correctly. \n", + "\n", + "We can achieve this by creating a vector store with all the distinct proper nouns that exist in the database. We can then have the agent query that vector store each time the user includes a proper noun in their question, to find the correct spelling for that word. In this way, the agent can make sure it understands which entity the user is referring to before building the target query.\n", + "\n", + "First we need the unique values for each entity we want, for which we define a function that parses the result into a list of elements:" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['For Those About To Rock We Salute You',\n", + " 'Balls to the Wall',\n", + " 'Restless and Wild',\n", + " 'Let There Be Rock',\n", + " 'Big Ones']" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import ast\n", + "import re\n", + "\n", + "\n", + "def query_as_list(db, query):\n", + " res = db.run(query)\n", + " res = [el for sub in ast.literal_eval(res) for el in sub if el]\n", + " res = [re.sub(r\"\\b\\d+\\b\", \"\", string).strip() for string in res]\n", + " return res\n", + "\n", + "\n", + "artists = query_as_list(db, \"SELECT Name FROM Artist\")\n", + "albums = query_as_list(db, \"SELECT Title FROM Album\")\n", + "albums[:5]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we can proceed with creating the custom **retriever tool** and the final agent:" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.agents.agent_toolkits import create_retriever_tool\n", + "\n", + "vector_db = FAISS.from_texts(artists + albums, OpenAIEmbeddings())\n", + "retriever = vector_db.as_retriever(search_kwargs={\"k\": 5})\n", + "description = \"\"\"Use to look up values to filter on. Input is an approximate spelling of the proper noun, output is \\\n", + "valid proper nouns. 
Use the noun most similar to the search.\"\"\"\n", + "retriever_tool = create_retriever_tool(\n", + " retriever,\n", + " name=\"search_proper_nouns\",\n", + " description=description,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "system = \"\"\"You are an agent designed to interact with a SQL database.\n", + "Given an input question, create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.\n", + "Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.\n", + "You can order the results by a relevant column to return the most interesting examples in the database.\n", + "Never query for all the columns from a specific table, only ask for the relevant columns given the question.\n", + "You have access to tools for interacting with the database.\n", + "Only use the given tools. Only use the information returned by the tools to construct your final answer.\n", + "You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.\n", + "\n", + "DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.\n", + "\n", + "If you need to filter on a proper noun, you must ALWAYS first look up the filter value using the \"search_proper_nouns\" tool!\n", + "\n", + "You have access to the following tables: {table_names}\n", + "\n", + "If the question does not seem related to the database, just return \"I don't know\" as the answer.\"\"\"\n", + "\n", + "prompt = ChatPromptTemplate.from_messages([(\"system\", system), (\"human\", \"{input}\"), MessagesPlaceholder(\"agent_scratchpad\")])\n", + "agent = create_sql_agent(\n", + " llm=llm,\n", + " db=db,\n", + " extra_tools=[retriever_tool],\n", + " prompt=prompt,\n", + " agent_type=\"openai-tools\",\n", + " verbose=True,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", + "\u001b[32;1m\u001b[1;3m\n", + "Invoking: `search_proper_nouns` with `{'query': 'alice in chains'}`\n", + "\n", + "\n", + "\u001b[0m\u001b[36;1m\u001b[1;3mAlice In Chains\n", + "\n", + "Metallica\n", + "\n", + "Pearl Jam\n", + "\n", + "Pearl Jam\n", + "\n", + "Smashing Pumpkins\u001b[0m\u001b[32;1m\u001b[1;3m\n", + "Invoking: `sql_db_query` with `{'query': \"SELECT COUNT(*) FROM Album WHERE ArtistId = (SELECT ArtistId FROM Artist WHERE Name = 'Alice In Chains')\"}`\n", + "\n", + "\n", + "\u001b[0m\u001b[36;1m\u001b[1;3m[(1,)]\u001b[0m\u001b[32;1m\u001b[1;3mAlice In Chains has 1 album.\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "{'input': 'How many albums does alice in chains have?',\n", + " 'output': 'Alice In Chains has 1 album.'}" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "agent.invoke({\"input\": \"How many albums does alice in chains have?\"})" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "As we can see, the agent used the `search_proper_nouns` tool in order to check how to correctly query the database for this specific artist." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Next steps\n", + "\n", + "Under the hood, `create_sql_agent` is just passing in SQL tools to more generic agent constructors. To learn more about the built-in generic agent types as well as how to build custom agents, head to the [Agents Modules](/docs/modules/agents/).\n", + "\n", + "The built-in `AgentExecutor` runs a simple Agent action -> Tool call -> Agent action... loop. To build more complex agent runtimes, head to the [LangGraph section](/docs/langgraph)." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "poetry-venv", + "language": "python", + "name": "poetry-venv" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.1" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/docs/docs/use_cases/sql/index.ipynb b/docs/docs/use_cases/sql/index.ipynb new file mode 100644 index 0000000000000..1b706c831b168 --- /dev/null +++ b/docs/docs/use_cases/sql/index.ipynb @@ -0,0 +1,68 @@ +{ + "cells": [ + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "---\n", + "sidebar_position: 0.5\n", + "---" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# SQL\n", + "\n", + "One of the most common types of databases that we can build Q&A systems for are SQL databases. LangChain comes with a number of built-in chains and agents that are compatible with any SQL dialect supported by SQLAlchemy (e.g., MySQL, PostgreSQL, Oracle SQL, Databricks, SQLite). They enable use cases such as:\n", + "\n", + "* Generating queries that will be run based on natural language questions,\n", + "* Creating chatbots that can answer questions based on database data,\n", + "* Building custom dashboards based on insights a user wants to analyze,\n", + "\n", + "and much more.\n", + "\n", + "## ⚠️ Security note ⚠️\n", + "\n", + "Building Q&A systems of SQL databases requires executing model-generated SQL queries. There are inherent risks in doing this. Make sure that your database connection permissions are always scoped as narrowly as possible for your chain/agent's needs. This will mitigate though not eliminate the risks of building a model-driven system. For more on general security best practices, [see here](/docs/security).\n", + "\n", + "![sql_usecase.png](../../../static/img/sql_usecase.png)\n", + "\n", + "## Quickstart\n", + "\n", + "Head to the **[Quickstart](/docs/use_cases/sql/quickstart)** page to get started.\n", + "\n", + "## Advanced\n", + "\n", + "Once you've familiarized yourself with the basics, you can head to the advanced guides:\n", + "\n", + "* [Agents](/docs/use_cases/sql/agents): Building agents that can interact with SQL DBs.\n", + "* [Prompting strategies](/docs/use_cases/sql/prompting): Strategies for improving SQL query generation.\n", + "* [Query validation](/docs/use_cases/sql/query_checking): How to validate SQL queries.\n", + "* [Large databases](/docs/use_cases/sql/large_db): How to interact with DBs with many tables and high-cardinality columns." 
+ ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "poetry-venv", + "language": "python", + "name": "poetry-venv" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.1" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/docs/docs/use_cases/sql/large_db.ipynb b/docs/docs/use_cases/sql/large_db.ipynb new file mode 100644 index 0000000000000..dd034037d1a4b --- /dev/null +++ b/docs/docs/use_cases/sql/large_db.ipynb @@ -0,0 +1,627 @@ +{ + "cells": [ + { + "cell_type": "raw", + "id": "b2788654-3f62-4e2a-ab00-471922cc54df", + "metadata": {}, + "source": [ + "---\n", + "sidebar_position: 4\n", + "---" + ] + }, + { + "cell_type": "markdown", + "id": "6751831d-9b08-434f-829b-d0052a3b119f", + "metadata": {}, + "source": [ + "# Large databases\n", + "\n", + "In order to write valid queries against a database, we need to feed the model the table names, table schemas, and feature values for it to query over. When there are many tables, columns, and/or high-cardinality columns, it becomes impossible for us to dump the full information about our database in every prompt. Instead, we must find ways to dynamically insert into the prompt only the most relevant information. Let's take a look at some techniques for doing this.\n", + "\n", + "\n", + "## Setup\n", + "\n", + "First, get required packages and set environment variables:" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "9675e433-e608-469e-b04e-2847479a8310", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.2.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.3.2\u001b[0m\n", + "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n", + "Note: you may need to restart the kernel to use updated packages.\n" + ] + } + ], + "source": [ + "%pip install --upgrade --quiet langchain langchain-community langchain-openai" + ] + }, + { + "cell_type": "markdown", + "id": "4f56ff5d-b2e4-49e3-a0b4-fb99466cfedc", + "metadata": {}, + "source": [ + "We default to OpenAI models in this guide, but you can swap them out for the model provider of your choice." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "06d8dd03-2d7b-4fef-b145-43c074eacb8b", + "metadata": {}, + "outputs": [ + { + "name": "stdin", + "output_type": "stream", + "text": [ + " ········\n" + ] + } + ], + "source": [ + "import getpass\n", + "import os\n", + "\n", + "# os.environ[\"OPENAI_API_KEY\"] = getpass.getpass()\n", + "\n", + "# Uncomment the below to use LangSmith. Not required.\n", + "os.environ[\"LANGCHAIN_API_KEY\"] = getpass.getpass()\n", + "# os.environ[\"LANGCHAIN_TRACING_V2\"] = \"true\"" + ] + }, + { + "cell_type": "markdown", + "id": "590ee096-db88-42af-90d4-99b8149df753", + "metadata": {}, + "source": [ + "The below example will use a SQLite connection with Chinook database. 
Follow [these installation steps](https://database.guide/2-sample-databases-sqlite/) to create `Chinook.db` in the same directory as this notebook:\n",
+    "\n",
+    "* Save [this file](https://raw.githubusercontent.com/lerocha/chinook-database/master/ChinookDatabase/DataSources/Chinook_Sqlite.sql) as `Chinook_Sqlite.sql`\n",
+    "* Run `sqlite3 Chinook.db`\n",
+    "* Run `.read Chinook_Sqlite.sql`\n",
+    "* Test `SELECT * FROM Artist LIMIT 10;`\n",
+    "\n",
+    "Now, `Chinook.db` is in our directory and we can interface with it using the SQLAlchemy-driven [SQLDatabase](https://api.python.langchain.com/en/latest/utilities/langchain_community.utilities.sql_database.SQLDatabase.html) class:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "cebd3915-f58f-4e73-8459-265630ae8cd4",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "sqlite\n",
+      "['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "\"[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]\""
+      ]
+     },
+     "execution_count": 2,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "from langchain_community.utilities import SQLDatabase\n",
+    "\n",
+    "db = SQLDatabase.from_uri(\"sqlite:///Chinook.db\")\n",
+    "print(db.dialect)\n",
+    "print(db.get_usable_table_names())\n",
+    "db.run(\"SELECT * FROM Artist LIMIT 10;\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "2e572e1f-99b5-46a2-9023-76d1e6256c0a",
+   "metadata": {},
+   "source": [
+    "## Many tables\n",
+    "\n",
+    "One of the main pieces of information we need to include in our prompt is the schemas of the relevant tables. When we have very many tables, we can't fit all of the schemas in a single prompt. What we can do in such cases is first extract the names of the tables related to the user input, and then include only their schemas.\n",
+    "\n",
+    "One easy and reliable way to do this is using OpenAI function-calling and Pydantic models. LangChain comes with a built-in [create_extraction_chain_pydantic](https://api.python.langchain.com/en/latest/chains/langchain.chains.openai_tools.extraction.create_extraction_chain_pydantic.html) chain that lets us do just this:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 38,
+   "id": "d8236886-c54f-4bdb-ad74-2514888628fd",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "[Table(name='Genre'), Table(name='Artist'), Table(name='Track')]"
+      ]
+     },
+     "execution_count": 38,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "from langchain.chains.openai_tools import create_extraction_chain_pydantic\n",
+    "from langchain_core.pydantic_v1 import BaseModel, Field\n",
+    "from langchain_openai import ChatOpenAI\n",
+    "\n",
+    "llm = ChatOpenAI(model=\"gpt-3.5-turbo-1106\", temperature=0)\n",
+    "\n",
+    "\n",
+    "class Table(BaseModel):\n",
+    "    \"\"\"Table in SQL database.\"\"\"\n",
+    "\n",
+    "    name: str = Field(description=\"Name of table in SQL database.\")\n",
+    "\n",
+    "\n",
+    "table_names = \"\\n\".join(db.get_usable_table_names())\n",
+    "system = f\"\"\"Return the names of ALL the SQL tables that MIGHT be relevant to the user question. 
\\\n", + "The tables are:\n", + "\n", + "{table_names}\n", + "\n", + "Remember to include ALL POTENTIALLY RELEVANT tables, even if you're not sure that they're needed.\"\"\"\n", + "table_chain = create_extraction_chain_pydantic(Table, llm, system_message=system)\n", + "table_chain.invoke({\"input\": \"What are all the genres of Alanis Morisette songs\"})" + ] + }, + { + "cell_type": "markdown", + "id": "1641dbba-d359-4cb2-ac52-82dfae99f392", + "metadata": {}, + "source": [ + "This works pretty well! Except, as we'll see below, we actually need a few other tables as well. This would be pretty difficult for the model to know based just on the user question. In this case, we might think to simplify our model's job by grouping the tables together. We'll just ask the model to choose between categories \"Music\" and \"Business\", and then take care of selecting all the relevant tables from there:" + ] + }, + { + "cell_type": "code", + "execution_count": 59, + "id": "0ccb0bf5-c580-428f-9cde-a58772ae784e", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[Table(name='Music')]" + ] + }, + "execution_count": 59, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "system = f\"\"\"Return the names of the SQL tables that are relevant to the user question. \\\n", + "The tables are:\n", + "\n", + "Music\n", + "Business\"\"\"\n", + "category_chain = create_extraction_chain_pydantic(Table, llm, system_message=system)\n", + "category_chain.invoke({\"input\": \"What are all the genres of Alanis Morisette songs\"})" + ] + }, + { + "cell_type": "code", + "execution_count": 60, + "id": "ae4899fc-6f8a-4b10-983c-9e3fef4a7bb9", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['Album', 'Artist', 'Genre', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']" + ] + }, + "execution_count": 60, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from typing import List\n", + "\n", + "\n", + "def get_tables(categories: List[Table]) -> List[str]:\n", + " tables = []\n", + " for category in categories:\n", + " if category.name == \"Music\":\n", + " tables.extend(\n", + " [\n", + " \"Album\",\n", + " \"Artist\",\n", + " \"Genre\",\n", + " \"MediaType\",\n", + " \"Playlist\",\n", + " \"PlaylistTrack\",\n", + " \"Track\",\n", + " ]\n", + " )\n", + " elif category.name == \"Business\":\n", + " tables.extend([\"Customer\", \"Employee\", \"Invoice\", \"InvoiceLine\"])\n", + " return tables\n", + "\n", + "\n", + "table_chain = category_chain | get_tables # noqa\n", + "table_chain.invoke({\"input\": \"What are all the genres of Alanis Morisette songs\"})" + ] + }, + { + "cell_type": "markdown", + "id": "04d52d01-1ccf-4753-b34a-0dcbc4921f78", + "metadata": {}, + "source": [ + "Now that we've got a chain that can output the relevant tables for any query we can combine this with our [create_sql_query_chain](https://api.python.langchain.com/en/latest/chains/langchain.chains.sql_database.query.create_sql_query_chain.html), which can accept a list of `table_names_to_use` to determine which table schemas are included in the prompt:" + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "id": "79f2a5a2-eb99-47e3-9c2b-e5751a800174", + "metadata": {}, + "outputs": [], + "source": [ + "from operator import itemgetter\n", + "\n", + "from langchain.chains import create_sql_query_chain\n", + "from langchain_core.runnables import RunnablePassthrough\n", + "\n", + "query_chain = create_sql_query_chain(llm, db)\n", + "# Convert \"question\" key 
to the \"input\" key expected by current table_chain.\n", + "table_chain = {\"input\": itemgetter(\"question\")} | table_chain\n", + "# Set table_names_to_use using table_chain.\n", + "full_chain = RunnablePassthrough.assign(table_names_to_use=table_chain) | query_chain" + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "id": "424a7564-f63c-4584-b734-88021926486d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "SELECT \"Genre\".\"Name\"\n", + "FROM \"Genre\"\n", + "JOIN \"Track\" ON \"Genre\".\"GenreId\" = \"Track\".\"GenreId\"\n", + "JOIN \"Album\" ON \"Track\".\"AlbumId\" = \"Album\".\"AlbumId\"\n", + "JOIN \"Artist\" ON \"Album\".\"ArtistId\" = \"Artist\".\"ArtistId\"\n", + "WHERE \"Artist\".\"Name\" = 'Alanis Morissette'\n" + ] + } + ], + "source": [ + "query = full_chain.invoke(\n", + " {\"question\": \"What are all the genres of Alanis Morisette songs\"}\n", + ")\n", + "print(query)" + ] + }, + { + "cell_type": "code", + "execution_count": 53, + "id": "3fb715cf-69d1-46a6-a1a7-9715ee550a0c", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "\"[('Rock',), ('Rock',), ('Rock',), ('Rock',), ('Rock',), ('Rock',), ('Rock',), ('Rock',), ('Rock',), ('Rock',), ('Rock',), ('Rock',), ('Rock',)]\"" + ] + }, + "execution_count": 53, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "db.run(query)" + ] + }, + { + "cell_type": "markdown", + "id": "bb3d12b0-81a6-4250-8bc4-d58fe762c4cc", + "metadata": {}, + "source": [ + "We might rephrase our question slightly to remove redundancy in the answer" + ] + }, + { + "cell_type": "code", + "execution_count": 57, + "id": "010b5c3c-d55b-461a-8de5-8f1a8b2c56ec", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "SELECT DISTINCT g.Name\n", + "FROM Genre g\n", + "JOIN Track t ON g.GenreId = t.GenreId\n", + "JOIN Album a ON t.AlbumId = a.AlbumId\n", + "JOIN Artist ar ON a.ArtistId = ar.ArtistId\n", + "WHERE ar.Name = 'Alanis Morissette'\n" + ] + } + ], + "source": [ + "query = full_chain.invoke(\n", + " {\"question\": \"What is the set of all unique genres of Alanis Morisette songs\"}\n", + ")\n", + "print(query)" + ] + }, + { + "cell_type": "code", + "execution_count": 58, + "id": "d21c0563-1f55-4577-8222-b0e9802f1c4b", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "\"[('Rock',)]\"" + ] + }, + "execution_count": 58, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "db.run(query)" + ] + }, + { + "cell_type": "markdown", + "id": "7a717020-84c2-40f3-ba84-6624138d8e0c", + "metadata": {}, + "source": [ + "We can see the [LangSmith trace](https://smith.langchain.com/public/20b8ef90-1dac-4754-90f0-6bc11203c50a/r) for this run here.\n", + "\n", + "We've seen how to dynamically include a subset of table schemas in a prompt within a chain. Another possible approach to this problem is to let an Agent decide for itself when to look up tables by giving it a Tool to do so. You can see an example of this in the [SQL: Agents](/docs/use_cases/sql/agents) guide." + ] + }, + { + "cell_type": "markdown", + "id": "cb9e54fd-64ca-4ed5-847c-afc635aae4f5", + "metadata": {}, + "source": [ + "## High-cardinality columns\n", + "\n", + "In order to filter columns that contain proper nouns such as addresses, song names or artists, we first need to double-check the spelling in order to filter the data correctly. 
\n", + "\n", + "One naive strategy it to create a vector store with all the distinct proper nouns that exist in the database. We can then query that vector store each user input and inject the most relevant proper nouns into the prompt.\n", + "\n", + "First we need the unique values for each entity we want, for which we define a function that parses the result into a list of elements:" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "dee1b9e1-36b0-4cc1-ab78-7a872ad87e29", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['AC/DC', 'Accept', 'Aerosmith', 'Alanis Morissette', 'Alice In Chains']" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import ast\n", + "import re\n", + "\n", + "\n", + "def query_as_list(db, query):\n", + " res = db.run(query)\n", + " res = [el for sub in ast.literal_eval(res) for el in sub if el]\n", + " res = [re.sub(r\"\\b\\d+\\b\", \"\", string).strip() for string in res]\n", + " return res\n", + "\n", + "\n", + "proper_nouns = query_as_list(db, \"SELECT Name FROM Artist\")\n", + "proper_nouns += query_as_list(db, \"SELECT Title FROM Album\")\n", + "proper_nouns += query_as_list(db, \"SELECT Name FROM Genre\")\n", + "len(proper_nouns)\n", + "proper_nouns[:5]" + ] + }, + { + "cell_type": "markdown", + "id": "22efa968-1879-4d7a-858f-7899dfa57454", + "metadata": {}, + "source": [ + "Now we can embed and store all of our values in a vector database:" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "ea50abce-545a-4dc3-8795-8d364f7d142a", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain_community.vectorstores import FAISS\n", + "from langchain_openai import OpenAIEmbeddings\n", + "\n", + "vector_db = FAISS.from_texts(proper_nouns, OpenAIEmbeddings())\n", + "retriever = vector_db.as_retriever(search_kwargs={\"k\": 15})" + ] + }, + { + "cell_type": "markdown", + "id": "a5d1d5c0-0928-40a4-b961-f1afe03cd5d3", + "metadata": {}, + "source": [ + "And put together a query construction chain that first retrieves values from the database and inserts them into the prompt:" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "aea123ae-d809-44a0-be5d-d883c60d6a11", + "metadata": {}, + "outputs": [], + "source": [ + "from operator import itemgetter\n", + "\n", + "from langchain_core.prompts import ChatPromptTemplate\n", + "from langchain_core.runnables import RunnablePassthrough\n", + "\n", + "system = \"\"\"You are a SQLite expert. Given an input question, create a syntactically \\\n", + "correct SQLite query to run. Unless otherwise specificed, do not return more than \\\n", + "{top_k} rows.\\n\\nHere is the relevant table info: {table_info}\\n\\nHere is a non-exhaustive \\\n", + "list of possible feature values. 
If filtering on a feature value, make sure to check its spelling \\\n", +    "against this list first:\\n\\n{proper_nouns}\"\"\"\n", +    "\n", +    "prompt = ChatPromptTemplate.from_messages([(\"system\", system), (\"human\", \"{input}\")])\n", +    "\n", +    "query_chain = create_sql_query_chain(llm, db, prompt=prompt)\n", +    "retriever_chain = (\n", +    "    itemgetter(\"question\")\n", +    "    | retriever\n", +    "    | (lambda docs: \"\\n\".join(doc.page_content for doc in docs))\n", +    ")\n", +    "chain = RunnablePassthrough.assign(proper_nouns=retriever_chain) | query_chain" +   ] +  }, +  { +   "cell_type": "markdown", +   "id": "12b0ed60-2536-4f82-85df-e096a272072a", +   "metadata": {}, +   "source": [ +    "To try out our chain, let's see what happens when we try filtering on \"elenis moriset\", a misspelling of Alanis Morissette, without and with retrieval:" +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": 15, +   "id": "fcdd8432-07a4-4609-8214-b1591dd94950", +   "metadata": {}, +   "outputs": [ +    { +     "name": "stdout", +     "output_type": "stream", +     "text": [ +      "SELECT DISTINCT Genre.Name\n", +      "FROM Genre\n", +      "JOIN Track ON Genre.GenreId = Track.GenreId\n", +      "JOIN Album ON Track.AlbumId = Album.AlbumId\n", +      "JOIN Artist ON Album.ArtistId = Artist.ArtistId\n", +      "WHERE Artist.Name = 'Elenis Moriset'\n" +     ] +    }, +    { +     "data": { +      "text/plain": [ +       "''" +      ] +     }, +     "execution_count": 15, +     "metadata": {}, +     "output_type": "execute_result" +    } +   ], +   "source": [ +    "# Without retrieval\n", +    "query = query_chain.invoke(\n", +    "    {\"question\": \"What are all the genres of elenis moriset songs\", \"proper_nouns\": \"\"}\n", +    ")\n", +    "print(query)\n", +    "db.run(query)" +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": 16, +   "id": "e8a3231a-8590-46f5-a954-da06829ee6df", +   "metadata": {}, +   "outputs": [ +    { +     "name": "stdout", +     "output_type": "stream", +     "text": [ +      "SELECT DISTINCT Genre.Name\n", +      "FROM Genre\n", +      "JOIN Track ON Genre.GenreId = Track.GenreId\n", +      "JOIN Album ON Track.AlbumId = Album.AlbumId\n", +      "JOIN Artist ON Album.ArtistId = Artist.ArtistId\n", +      "WHERE Artist.Name = 'Alanis Morissette'\n" +     ] +    }, +    { +     "data": { +      "text/plain": [ +       "\"[('Rock',)]\"" +      ] +     }, +     "execution_count": 16, +     "metadata": {}, +     "output_type": "execute_result" +    } +   ], +   "source": [ +    "# With retrieval\n", +    "query = chain.invoke({\"question\": \"What are all the genres of elenis moriset songs\"})\n", +    "print(query)\n", +    "db.run(query)" +   ] +  }, +  { +   "cell_type": "markdown", +   "id": "7f99181b-a75c-4ff3-b37b-33f99a506581", +   "metadata": {}, +   "source": [ +    "We can see that with retrieval we're able to correct the spelling and get back a valid result.\n", +    "\n", +    "Another possible approach to this problem is to let an Agent decide for itself when to look up proper nouns, as in the minimal sketch below. You can see an example of this in the [SQL: Agents](/docs/use_cases/sql/agents) guide." 
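+   ] +  }, +  { +   "cell_type": "markdown", +   "id": "0f2b7a1e-5c3d-4a9b-8e2f-6c1d9a7b4e21", +   "metadata": {}, +   "source": [ +    "As a quick taste of that agent approach, we can expose our proper-noun retriever as a tool an agent could choose to call. This is only a minimal sketch, and the tool name and description here are our own illustrative choices rather than anything from the guide above; see the agents guide for the full pattern:" +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": null, +   "id": "0f2b7a1e-5c3d-4a9b-8e2f-6c1d9a7b4e22", +   "metadata": {}, +   "outputs": [], +   "source": [ +    "from langchain.tools.retriever import create_retriever_tool\n", +    "\n", +    "# Hypothetical tool an agent could call to spell-check proper nouns\n", +    "# before filtering on them.\n", +    "search_proper_nouns = create_retriever_tool(\n", +    "    retriever,\n", +    "    name=\"search_proper_nouns\",\n", +    "    description=\"Look up the correct spelling of artists, albums and genres.\",\n", +    ")\n", +    "print(search_proper_nouns.invoke(\"elenis moriset\"))"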
+ ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "poetry-venv", + "language": "python", + "name": "poetry-venv" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.1" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/docs/use_cases/sql/prompting.ipynb b/docs/docs/use_cases/sql/prompting.ipynb new file mode 100644 index 0000000000000..27a60b3de9d7e --- /dev/null +++ b/docs/docs/use_cases/sql/prompting.ipynb @@ -0,0 +1,789 @@ +{ + "cells": [ + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "---\n", + "sidebar_position: 2\n", + "---" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Prompting strategies\n", + "\n", + "In this guide we'll go over prompting strategies to improve SQL query generation. We'll largely focus on methods for getting relevant database-specific information in your prompt.\n", + "\n", + "## Setup\n", + "\n", + "First, get required packages and set environment variables:" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "%pip install --upgrade --quiet langchain langchain-community langchain-experimental langchain-openai" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We default to OpenAI models in this guide, but you can swap them out for the model provider of your choice." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import getpass\n", + "import os\n", + "\n", + "os.environ[\"OPENAI_API_KEY\"] = getpass.getpass()\n", + "\n", + "# Uncomment the below to use LangSmith. Not required.\n", + "# os.environ[\"LANGCHAIN_API_KEY\"] = getpass.getpass()\n", + "# os.environ[\"LANGCHAIN_TRACING_V2\"] = \"true\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The below example will use a SQLite connection with Chinook database. 
Follow [these installation steps](https://database.guide/2-sample-databases-sqlite/) to create `Chinook.db` in the same directory as this notebook:\n", +    "\n", +    "* Save [this file](https://raw.githubusercontent.com/lerocha/chinook-database/master/ChinookDatabase/DataSources/Chinook_Sqlite.sql) as `Chinook_Sqlite.sql`\n", +    "* Run `sqlite3 Chinook.db`\n", +    "* Run `.read Chinook_Sqlite.sql`\n", +    "* Test `SELECT * FROM Artist LIMIT 10;`\n", +    "\n", +    "Now, `Chinook.db` is in our directory and we can interface with it using the SQLAlchemy-driven `SQLDatabase` class:" +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": 6, +   "metadata": {}, +   "outputs": [ +    { +     "name": "stdout", +     "output_type": "stream", +     "text": [ +      "sqlite\n", +      "['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']\n" +     ] +    }, +    { +     "data": { +      "text/plain": [ +       "\"[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]\"" +      ] +     }, +     "execution_count": 6, +     "metadata": {}, +     "output_type": "execute_result" +    } +   ], +   "source": [ +    "from langchain_community.utilities import SQLDatabase\n", +    "\n", +    "db = SQLDatabase.from_uri(\"sqlite:///Chinook.db\", sample_rows_in_table_info=3)\n", +    "print(db.dialect)\n", +    "print(db.get_usable_table_names())\n", +    "db.run(\"SELECT * FROM Artist LIMIT 10;\")" +   ] +  }, +  { +   "cell_type": "markdown", +   "metadata": {}, +   "source": [ +    "## Dialect-specific prompting\n", +    "\n", +    "One of the simplest things we can do is make our prompt specific to the SQL dialect we're using. When using the built-in [create_sql_query_chain](https://api.python.langchain.com/en/latest/chains/langchain.chains.sql_database.query.create_sql_query_chain.html) and [SQLDatabase](https://api.python.langchain.com/en/latest/utilities/langchain_community.utilities.sql_database.SQLDatabase.html), this is handled for you for any of the following dialects:" +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": 2, +   "metadata": {}, +   "outputs": [ +    { +     "data": { +      "text/plain": [ +       "['crate',\n", +       " 'duckdb',\n", +       " 'googlesql',\n", +       " 'mssql',\n", +       " 'mysql',\n", +       " 'mariadb',\n", +       " 'oracle',\n", +       " 'postgresql',\n", +       " 'sqlite',\n", +       " 'clickhouse',\n", +       " 'prestodb']" +      ] +     }, +     "execution_count": 2, +     "metadata": {}, +     "output_type": "execute_result" +    } +   ], +   "source": [ +    "from langchain.chains.sql_database.prompt import SQL_PROMPTS\n", +    "\n", +    "list(SQL_PROMPTS)" +   ] +  }, +  { +   "cell_type": "markdown", +   "metadata": {}, +   "source": [ +    "For example, using our current DB we can see that we'll get a SQLite-specific prompt:" +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": 3, +   "metadata": {}, +   "outputs": [ +    { +     "name": "stdout", +     "output_type": "stream", +     "text": [ +      "You are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer to the input question.\n", +      "Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.\n", +      "Never query for all columns from a table. You must query only the columns that are needed to answer the question. 
Wrap each column name in double quotes (\") to denote them as delimited identifiers.\n", +      "Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.\n", +      "Pay attention to use date('now') function to get the current date, if the question involves \"today\".\n", +      "\n", +      "Use the following format:\n", +      "\n", +      "Question: Question here\n", +      "SQLQuery: SQL Query to run\n", +      "SQLResult: Result of the SQLQuery\n", +      "Answer: Final answer here\n", +      "\n", +      "Only use the following tables:\n", +      "\u001b[33;1m\u001b[1;3m{table_info}\u001b[0m\n", +      "\n", +      "Question: \u001b[33;1m\u001b[1;3m{input}\u001b[0m\n" +     ] +    } +   ], +   "source": [ +    "from langchain.chains import create_sql_query_chain\n", +    "from langchain_openai import ChatOpenAI\n", +    "\n", +    "llm = ChatOpenAI(model=\"gpt-3.5-turbo\", temperature=0)\n", +    "chain = create_sql_query_chain(llm, db)\n", +    "chain.get_prompts()[0].pretty_print()" +   ] +  }, +  { +   "cell_type": "markdown", +   "metadata": {}, +   "source": [ +    "## Table definitions and example rows\n", +    "\n", +    "In basically any SQL chain, we'll need to feed the model at least part of the database schema. Without this it won't be able to write valid queries. Our database comes with some convenience methods to give us the relevant context. Specifically, we can get the table names, their schemas, and a sample of rows from each table:" +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": 5, +   "metadata": {}, +   "outputs": [ +    { +     "name": "stdout", +     "output_type": "stream", +     "text": [ +      "['table_info', 'table_names']\n", +      "\n", +      "CREATE TABLE \"Album\" (\n", +      "\t\"AlbumId\" INTEGER NOT NULL, \n", +      "\t\"Title\" NVARCHAR(160) NOT NULL, \n", +      "\t\"ArtistId\" INTEGER NOT NULL, \n", +      "\tPRIMARY KEY (\"AlbumId\"), \n", +      "\tFOREIGN KEY(\"ArtistId\") REFERENCES \"Artist\" (\"ArtistId\")\n", +      ")\n", +      "\n", +      "/*\n", +      "3 rows from Album table:\n", +      "AlbumId\tTitle\tArtistId\n", +      "1\tFor Those About To Rock We Salute You\t1\n", +      "2\tBalls to the Wall\t2\n", +      "3\tRestless and Wild\t2\n", +      "*/\n", +      "\n", +      "\n", +      "CREATE TABLE \"Artist\" (\n", +      "\t\"ArtistId\" INTEGER NOT NULL, \n", +      "\t\"Name\" NVARCHAR(120), \n", +      "\tPRIMARY KEY (\"ArtistId\")\n", +      ")\n", +      "\n", +      "/*\n", +      "3 rows from Artist table:\n", +      "ArtistId\tName\n", +      "1\tAC/DC\n", +      "2\tAccept\n", +      "3\tAerosmith\n", +      "*/\n", +      "\n", +      "\n", +      "CREATE TABLE \"Customer\" (\n", +      "\t\"CustomerId\" INTEGER NOT NULL, \n", +      "\t\"FirstName\" NVARCHAR(40) NOT NULL, \n", +      "\t\"LastName\" NVARCHAR(20) NOT NULL, \n", +      "\t\"Company\" NVARCHAR(80), \n", +      "\t\"Address\" NVARCHAR(70), \n", +      "\t\"City\" NVARCHAR(40), \n", +      "\t\"State\" NVARCHAR(40), \n", +      "\t\"Country\" NVARCHAR(40), \n", +      "\t\"PostalCode\" NVARCHAR(10), \n", +      "\t\"Phone\" NVARCHAR(24), \n", +      "\t\"Fax\" NVARCHAR(24), \n", +      "\t\"Email\" NVARCHAR(60) NOT NULL, \n", +      "\t\"SupportRepId\" INTEGER, \n", +      "\tPRIMARY KEY (\"CustomerId\"), \n", +      "\tFOREIGN KEY(\"SupportRepId\") REFERENCES \"Employee\" (\"EmployeeId\")\n", +      ")\n", +      "\n", +      "/*\n", +      "3 rows from Customer table:\n", +      "CustomerId\tFirstName\tLastName\tCompany\tAddress\tCity\tState\tCountry\tPostalCode\tPhone\tFax\tEmail\tSupportRepId\n", +      "1\tLuís\tGonçalves\tEmbraer - Empresa Brasileira de Aeronáutica S.A.\tAv. 
Brigadeiro Faria Lima, 2170\tSão José dos Campos\tSP\tBrazil\t12227-000\t+55 (12) 3923-5555\t+55 (12) 3923-5566\tluisg@embraer.com.br\t3\n", + "2\tLeonie\tKöhler\tNone\tTheodor-Heuss-Straße 34\tStuttgart\tNone\tGermany\t70174\t+49 0711 2842222\tNone\tleonekohler@surfeu.de\t5\n", + "3\tFrançois\tTremblay\tNone\t1498 rue Bélanger\tMontréal\tQC\tCanada\tH2G 1A7\t+1 (514) 721-4711\tNone\tftremblay@gmail.com\t3\n", + "*/\n", + "\n", + "\n", + "CREATE TABLE \"Employee\" (\n", + "\t\"EmployeeId\" INTEGER NOT NULL, \n", + "\t\"LastName\" NVARCHAR(20) NOT NULL, \n", + "\t\"FirstName\" NVARCHAR(20) NOT NULL, \n", + "\t\"Title\" NVARCHAR(30), \n", + "\t\"ReportsTo\" INTEGER, \n", + "\t\"BirthDate\" DATETIME, \n", + "\t\"HireDate\" DATETIME, \n", + "\t\"Address\" NVARCHAR(70), \n", + "\t\"City\" NVARCHAR(40), \n", + "\t\"State\" NVARCHAR(40), \n", + "\t\"Country\" NVARCHAR(40), \n", + "\t\"PostalCode\" NVARCHAR(10), \n", + "\t\"Phone\" NVARCHAR(24), \n", + "\t\"Fax\" NVARCHAR(24), \n", + "\t\"Email\" NVARCHAR(60), \n", + "\tPRIMARY KEY (\"EmployeeId\"), \n", + "\tFOREIGN KEY(\"ReportsTo\") REFERENCES \"Employee\" (\"EmployeeId\")\n", + ")\n", + "\n", + "/*\n", + "3 rows from Employee table:\n", + "EmployeeId\tLastName\tFirstName\tTitle\tReportsTo\tBirthDate\tHireDate\tAddress\tCity\tState\tCountry\tPostalCode\tPhone\tFax\tEmail\n", + "1\tAdams\tAndrew\tGeneral Manager\tNone\t1962-02-18 00:00:00\t2002-08-14 00:00:00\t11120 Jasper Ave NW\tEdmonton\tAB\tCanada\tT5K 2N1\t+1 (780) 428-9482\t+1 (780) 428-3457\tandrew@chinookcorp.com\n", + "2\tEdwards\tNancy\tSales Manager\t1\t1958-12-08 00:00:00\t2002-05-01 00:00:00\t825 8 Ave SW\tCalgary\tAB\tCanada\tT2P 2T3\t+1 (403) 262-3443\t+1 (403) 262-3322\tnancy@chinookcorp.com\n", + "3\tPeacock\tJane\tSales Support Agent\t2\t1973-08-29 00:00:00\t2002-04-01 00:00:00\t1111 6 Ave SW\tCalgary\tAB\tCanada\tT2P 5M5\t+1 (403) 262-3443\t+1 (403) 262-6712\tjane@chinookcorp.com\n", + "*/\n", + "\n", + "\n", + "CREATE TABLE \"Genre\" (\n", + "\t\"GenreId\" INTEGER NOT NULL, \n", + "\t\"Name\" NVARCHAR(120), \n", + "\tPRIMARY KEY (\"GenreId\")\n", + ")\n", + "\n", + "/*\n", + "3 rows from Genre table:\n", + "GenreId\tName\n", + "1\tRock\n", + "2\tJazz\n", + "3\tMetal\n", + "*/\n", + "\n", + "\n", + "CREATE TABLE \"Invoice\" (\n", + "\t\"InvoiceId\" INTEGER NOT NULL, \n", + "\t\"CustomerId\" INTEGER NOT NULL, \n", + "\t\"InvoiceDate\" DATETIME NOT NULL, \n", + "\t\"BillingAddress\" NVARCHAR(70), \n", + "\t\"BillingCity\" NVARCHAR(40), \n", + "\t\"BillingState\" NVARCHAR(40), \n", + "\t\"BillingCountry\" NVARCHAR(40), \n", + "\t\"BillingPostalCode\" NVARCHAR(10), \n", + "\t\"Total\" NUMERIC(10, 2) NOT NULL, \n", + "\tPRIMARY KEY (\"InvoiceId\"), \n", + "\tFOREIGN KEY(\"CustomerId\") REFERENCES \"Customer\" (\"CustomerId\")\n", + ")\n", + "\n", + "/*\n", + "3 rows from Invoice table:\n", + "InvoiceId\tCustomerId\tInvoiceDate\tBillingAddress\tBillingCity\tBillingState\tBillingCountry\tBillingPostalCode\tTotal\n", + "1\t2\t2009-01-01 00:00:00\tTheodor-Heuss-Straße 34\tStuttgart\tNone\tGermany\t70174\t1.98\n", + "2\t4\t2009-01-02 00:00:00\tUllevålsveien 14\tOslo\tNone\tNorway\t0171\t3.96\n", + "3\t8\t2009-01-03 00:00:00\tGrétrystraat 63\tBrussels\tNone\tBelgium\t1000\t5.94\n", + "*/\n", + "\n", + "\n", + "CREATE TABLE \"InvoiceLine\" (\n", + "\t\"InvoiceLineId\" INTEGER NOT NULL, \n", + "\t\"InvoiceId\" INTEGER NOT NULL, \n", + "\t\"TrackId\" INTEGER NOT NULL, \n", + "\t\"UnitPrice\" NUMERIC(10, 2) NOT NULL, \n", + "\t\"Quantity\" INTEGER NOT NULL, \n", + "\tPRIMARY KEY 
(\"InvoiceLineId\"), \n", + "\tFOREIGN KEY(\"TrackId\") REFERENCES \"Track\" (\"TrackId\"), \n", + "\tFOREIGN KEY(\"InvoiceId\") REFERENCES \"Invoice\" (\"InvoiceId\")\n", + ")\n", + "\n", + "/*\n", + "3 rows from InvoiceLine table:\n", + "InvoiceLineId\tInvoiceId\tTrackId\tUnitPrice\tQuantity\n", + "1\t1\t2\t0.99\t1\n", + "2\t1\t4\t0.99\t1\n", + "3\t2\t6\t0.99\t1\n", + "*/\n", + "\n", + "\n", + "CREATE TABLE \"MediaType\" (\n", + "\t\"MediaTypeId\" INTEGER NOT NULL, \n", + "\t\"Name\" NVARCHAR(120), \n", + "\tPRIMARY KEY (\"MediaTypeId\")\n", + ")\n", + "\n", + "/*\n", + "3 rows from MediaType table:\n", + "MediaTypeId\tName\n", + "1\tMPEG audio file\n", + "2\tProtected AAC audio file\n", + "3\tProtected MPEG-4 video file\n", + "*/\n", + "\n", + "\n", + "CREATE TABLE \"Playlist\" (\n", + "\t\"PlaylistId\" INTEGER NOT NULL, \n", + "\t\"Name\" NVARCHAR(120), \n", + "\tPRIMARY KEY (\"PlaylistId\")\n", + ")\n", + "\n", + "/*\n", + "3 rows from Playlist table:\n", + "PlaylistId\tName\n", + "1\tMusic\n", + "2\tMovies\n", + "3\tTV Shows\n", + "*/\n", + "\n", + "\n", + "CREATE TABLE \"PlaylistTrack\" (\n", + "\t\"PlaylistId\" INTEGER NOT NULL, \n", + "\t\"TrackId\" INTEGER NOT NULL, \n", + "\tPRIMARY KEY (\"PlaylistId\", \"TrackId\"), \n", + "\tFOREIGN KEY(\"TrackId\") REFERENCES \"Track\" (\"TrackId\"), \n", + "\tFOREIGN KEY(\"PlaylistId\") REFERENCES \"Playlist\" (\"PlaylistId\")\n", + ")\n", + "\n", + "/*\n", + "3 rows from PlaylistTrack table:\n", + "PlaylistId\tTrackId\n", + "1\t3402\n", + "1\t3389\n", + "1\t3390\n", + "*/\n", + "\n", + "\n", + "CREATE TABLE \"Track\" (\n", + "\t\"TrackId\" INTEGER NOT NULL, \n", + "\t\"Name\" NVARCHAR(200) NOT NULL, \n", + "\t\"AlbumId\" INTEGER, \n", + "\t\"MediaTypeId\" INTEGER NOT NULL, \n", + "\t\"GenreId\" INTEGER, \n", + "\t\"Composer\" NVARCHAR(220), \n", + "\t\"Milliseconds\" INTEGER NOT NULL, \n", + "\t\"Bytes\" INTEGER, \n", + "\t\"UnitPrice\" NUMERIC(10, 2) NOT NULL, \n", + "\tPRIMARY KEY (\"TrackId\"), \n", + "\tFOREIGN KEY(\"MediaTypeId\") REFERENCES \"MediaType\" (\"MediaTypeId\"), \n", + "\tFOREIGN KEY(\"GenreId\") REFERENCES \"Genre\" (\"GenreId\"), \n", + "\tFOREIGN KEY(\"AlbumId\") REFERENCES \"Album\" (\"AlbumId\")\n", + ")\n", + "\n", + "/*\n", + "3 rows from Track table:\n", + "TrackId\tName\tAlbumId\tMediaTypeId\tGenreId\tComposer\tMilliseconds\tBytes\tUnitPrice\n", + "1\tFor Those About To Rock (We Salute You)\t1\t1\t1\tAngus Young, Malcolm Young, Brian Johnson\t343719\t11170334\t0.99\n", + "2\tBalls to the Wall\t2\t2\t1\tNone\t342562\t5510424\t0.99\n", + "3\tFast As a Shark\t3\t2\t1\tF. Baltes, S. Kaufman, U. Dirkscneider & W. Hoffman\t230619\t3990994\t0.99\n", + "*/\n" + ] + } + ], + "source": [ + "context = db.get_context()\n", + "print(list(context))\n", + "print(context[\"table_info\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "When we don't have too many, or too wide of, tables, we can just insert the entirety of this information in our prompt:" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "You are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer to the input question.\n", + "Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per SQLite. 
You can order the results to return the most informative data in the database.\n", +      "Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (\") to denote them as delimited identifiers.\n", +      "Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.\n", +      "Pay attention to use date('now') function to get the current date, if the question involves \"today\".\n", +      "\n", +      "Use the following format:\n", +      "\n", +      "Question: Question here\n", +      "SQLQuery: SQL Query to run\n", +      "SQLResult: Result of the SQLQuery\n", +      "Answer: Final answer here\n", +      "\n", +      "Only use the following tables:\n", +      "\n", +      "CREATE TABLE \"Album\" (\n", +      "\t\"AlbumId\" INTEGER NOT NULL, \n", +      "\t\"Title\" NVARCHAR(160) NOT NULL, \n", +      "\t\"ArtistId\" INTEGER NOT NULL, \n", +      "\tPRIMARY KEY (\"AlbumId\"), \n", +      "\tFOREIGN KEY(\"ArtistId\") REFERENCES \"Artist\" (\"ArtistId\")\n", +      ")\n", +      "\n", +      "/*\n", +      "3 rows from Album table:\n", +      "AlbumId\tTitle\tArtistId\n", +      "1\tFor Those About To Rock We Salute You\t1\n", +      "2\tBalls to the Wall\t2\n", +      "3\tRestless and Wild\t2\n", +      "*/\n", +      "\n", +      "\n", +      "CREATE TABLE \"Artist\" (\n", +      "\t\"ArtistId\" INTEGER NOT NULL, \n", +      "\t\"Name\" NVARCHAR(120)\n" +     ] +    } +   ], +   "source": [ +    "prompt_with_context = chain.get_prompts()[0].partial(table_info=context[\"table_info\"])\n", +    "print(prompt_with_context.pretty_repr()[:1500])" +   ] +  }, +  { +   "cell_type": "markdown", +   "metadata": {}, +   "source": [ +    "When we do have database schemas that are too large to fit into our model's context window, we'll need to come up with ways of inserting only the relevant table definitions into the prompt based on the user input, as in the sketch below. For more on this, head to the [Many tables, wide tables, high-cardinality columns](/docs/use_cases/sql/large_db) guide."
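+   ] +  }, +  { +   "cell_type": "markdown", +   "metadata": {}, +   "source": [ +    "For example, if we already knew that a question only concerned the `Album` and `Artist` tables, we could fetch just those definitions ourselves. This is only a minimal sketch, with the table subset hard-coded for illustration; deciding which tables are actually relevant to a given question is what that guide covers:" +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": null, +   "metadata": {}, +   "outputs": [], +   "source": [ +    "# Fetch schema and sample rows for a hard-coded subset of tables only.\n", +    "subset_info = db.get_table_info(table_names=[\"Album\", \"Artist\"])\n", +    "prompt_with_subset = chain.get_prompts()[0].partial(table_info=subset_info)\n", +    "print(prompt_with_subset.pretty_repr()[:1000])"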
+   ] +  }, +  { +   "cell_type": "markdown", +   "metadata": {}, +   "source": [ +    "## Few-shot examples\n", +    "\n", +    "Including examples of natural language questions being converted to valid SQL queries against our database in the prompt will often improve model performance, especially for complex queries.\n", +    "\n", +    "Let's say we have the following examples:" +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": 18, +   "metadata": {}, +   "outputs": [], +   "source": [ +    "examples = [\n", +    "    {\"input\": \"List all artists.\", \"query\": \"SELECT * FROM Artist;\"},\n", +    "    {\n", +    "        \"input\": \"Find all albums for the artist 'AC/DC'.\",\n", +    "        \"query\": \"SELECT * FROM Album WHERE ArtistId = (SELECT ArtistId FROM Artist WHERE Name = 'AC/DC');\",\n", +    "    },\n", +    "    {\n", +    "        \"input\": \"List all tracks in the 'Rock' genre.\",\n", +    "        \"query\": \"SELECT * FROM Track WHERE GenreId = (SELECT GenreId FROM Genre WHERE Name = 'Rock');\",\n", +    "    },\n", +    "    {\n", +    "        \"input\": \"Find the total duration of all tracks.\",\n", +    "        \"query\": \"SELECT SUM(Milliseconds) FROM Track;\",\n", +    "    },\n", +    "    {\n", +    "        \"input\": \"List all customers from Canada.\",\n", +    "        \"query\": \"SELECT * FROM Customer WHERE Country = 'Canada';\",\n", +    "    },\n", +    "    {\n", +    "        \"input\": \"How many tracks are there in the album with ID 5?\",\n", +    "        \"query\": \"SELECT COUNT(*) FROM Track WHERE AlbumId = 5;\",\n", +    "    },\n", +    "    {\n", +    "        \"input\": \"Find the total number of invoices.\",\n", +    "        \"query\": \"SELECT COUNT(*) FROM Invoice;\",\n", +    "    },\n", +    "    {\n", +    "        \"input\": \"List all tracks that are longer than 5 minutes.\",\n", +    "        \"query\": \"SELECT * FROM Track WHERE Milliseconds > 300000;\",\n", +    "    },\n", +    "    {\n", +    "        \"input\": \"Who are the top 5 customers by total purchase?\",\n", +    "        \"query\": \"SELECT CustomerId, SUM(Total) AS TotalPurchase FROM Invoice GROUP BY CustomerId ORDER BY TotalPurchase DESC LIMIT 5;\",\n", +    "    },\n", +    "    {\n", +    "        \"input\": \"Which albums are from the year 2000?\",\n", +    "        \"query\": \"SELECT * FROM Album WHERE strftime('%Y', ReleaseDate) = '2000';\",\n", +    "    },\n", +    "    {\n", +    "        \"input\": \"How many employees are there\",\n", +    "        \"query\": 'SELECT COUNT(*) FROM \"Employee\"',\n", +    "    },\n", +    "]" +   ] +  }, +  { +   "cell_type": "markdown", +   "metadata": {}, +   "source": [ +    "We can create a few-shot prompt with them like so:" +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": 43, +   "metadata": {}, +   "outputs": [], +   "source": [ +    "from langchain_core.prompts import FewShotPromptTemplate, PromptTemplate\n", +    "\n", +    "example_prompt = PromptTemplate.from_template(\"User input: {input}\\nSQL query: {query}\")\n", +    "prompt = FewShotPromptTemplate(\n", +    "    examples=examples[:5],\n", +    "    example_prompt=example_prompt,\n", +    "    prefix=\"You are a SQLite expert. Given an input question, create a syntactically correct SQLite query to run. Unless otherwise specified, do not return more than {top_k} rows.\\n\\nHere is the relevant table info: {table_info}\\n\\nBelow are a number of examples of questions and their corresponding SQL queries.\",\n", +    "    suffix=\"User input: {input}\\nSQL query: \",\n", +    "    input_variables=[\"input\", \"top_k\", \"table_info\"],\n", +    ")" +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": 44, +   "metadata": {}, +   "outputs": [ +    { +     "name": "stdout", +     "output_type": "stream", +     "text": [ +      "You are a SQLite expert. Given an input question, create a syntactically correct SQLite query to run. 
Unless otherwise specified, do not return more than 3 rows.\n", +      "\n", +      "Here is the relevant table info: foo\n", +      "\n", +      "Below are a number of examples of questions and their corresponding SQL queries.\n", +      "\n", +      "User input: List all artists.\n", +      "SQL query: SELECT * FROM Artist;\n", +      "\n", +      "User input: Find all albums for the artist 'AC/DC'.\n", +      "SQL query: SELECT * FROM Album WHERE ArtistId = (SELECT ArtistId FROM Artist WHERE Name = 'AC/DC');\n", +      "\n", +      "User input: List all tracks in the 'Rock' genre.\n", +      "SQL query: SELECT * FROM Track WHERE GenreId = (SELECT GenreId FROM Genre WHERE Name = 'Rock');\n", +      "\n", +      "User input: Find the total duration of all tracks.\n", +      "SQL query: SELECT SUM(Milliseconds) FROM Track;\n", +      "\n", +      "User input: List all customers from Canada.\n", +      "SQL query: SELECT * FROM Customer WHERE Country = 'Canada';\n", +      "\n", +      "User input: How many artists are there?\n", +      "SQL query: \n" +     ] +    } +   ], +   "source": [ +    "print(prompt.format(input=\"How many artists are there?\", top_k=3, table_info=\"foo\"))" +   ] +  }, +  { +   "cell_type": "markdown", +   "metadata": {}, +   "source": [ +    "## Dynamic few-shot examples\n", +    "\n", +    "If we have enough examples, we may want to only include the most relevant ones in the prompt, either because they don't fit in the model's context window or because the long tail of examples distracts the model. And specifically, given any input we want to include the examples most relevant to that input.\n", +    "\n", +    "We can do just this using an ExampleSelector. In this case we'll use a [SemanticSimilarityExampleSelector](https://api.python.langchain.com/en/latest/example_selectors/langchain_core.example_selectors.semantic_similarity.SemanticSimilarityExampleSelector.html), which will store the examples in the vector database of our choosing. 
At runtime it will perform a similarity search between the input and our examples, and return the most semantically similar ones: " +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": 45, +   "metadata": {}, +   "outputs": [], +   "source": [ +    "from langchain_community.vectorstores import FAISS\n", +    "from langchain_core.example_selectors import SemanticSimilarityExampleSelector\n", +    "from langchain_openai import OpenAIEmbeddings\n", +    "\n", +    "example_selector = SemanticSimilarityExampleSelector.from_examples(\n", +    "    examples,\n", +    "    OpenAIEmbeddings(),\n", +    "    FAISS,\n", +    "    k=5,\n", +    "    input_keys=[\"input\"],\n", +    ")" +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": 46, +   "metadata": {}, +   "outputs": [ +    { +     "data": { +      "text/plain": [ +       "[{'input': 'List all artists.', 'query': 'SELECT * FROM Artist;'},\n", +       " {'input': 'How many employees are there',\n", +       "  'query': 'SELECT COUNT(*) FROM \"Employee\"'},\n", +       " {'input': 'How many tracks are there in the album with ID 5?',\n", +       "  'query': 'SELECT COUNT(*) FROM Track WHERE AlbumId = 5;'},\n", +       " {'input': 'Which albums are from the year 2000?',\n", +       "  'query': \"SELECT * FROM Album WHERE strftime('%Y', ReleaseDate) = '2000';\"},\n", +       " {'input': \"List all tracks in the 'Rock' genre.\",\n", +       "  'query': \"SELECT * FROM Track WHERE GenreId = (SELECT GenreId FROM Genre WHERE Name = 'Rock');\"}]" +      ] +     }, +     "execution_count": 46, +     "metadata": {}, +     "output_type": "execute_result" +    } +   ], +   "source": [ +    "example_selector.select_examples({\"input\": \"how many artists are there?\"})" +   ] +  }, +  { +   "cell_type": "markdown", +   "metadata": {}, +   "source": [ +    "To use it, we can pass the ExampleSelector directly into our FewShotPromptTemplate:" +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": 49, +   "metadata": {}, +   "outputs": [], +   "source": [ +    "prompt = FewShotPromptTemplate(\n", +    "    example_selector=example_selector,\n", +    "    example_prompt=example_prompt,\n", +    "    prefix=\"You are a SQLite expert. Given an input question, create a syntactically correct SQLite query to run. Unless otherwise specified, do not return more than {top_k} rows.\\n\\nHere is the relevant table info: {table_info}\\n\\nBelow are a number of examples of questions and their corresponding SQL queries.\",\n", +    "    suffix=\"User input: {input}\\nSQL query: \",\n", +    "    input_variables=[\"input\", \"top_k\", \"table_info\"],\n", +    ")" +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": 50, +   "metadata": {}, +   "outputs": [ +    { +     "name": "stdout", +     "output_type": "stream", +     "text": [ +      "You are a SQLite expert. Given an input question, create a syntactically correct SQLite query to run. 
Unless otherwise specified, do not return more than 3 rows.\n", +      "\n", +      "Here is the relevant table info: foo\n", +      "\n", +      "Below are a number of examples of questions and their corresponding SQL queries.\n", +      "\n", +      "User input: List all artists.\n", +      "SQL query: SELECT * FROM Artist;\n", +      "\n", +      "User input: How many employees are there\n", +      "SQL query: SELECT COUNT(*) FROM \"Employee\"\n", +      "\n", +      "User input: How many tracks are there in the album with ID 5?\n", +      "SQL query: SELECT COUNT(*) FROM Track WHERE AlbumId = 5;\n", +      "\n", +      "User input: Which albums are from the year 2000?\n", +      "SQL query: SELECT * FROM Album WHERE strftime('%Y', ReleaseDate) = '2000';\n", +      "\n", +      "User input: List all tracks in the 'Rock' genre.\n", +      "SQL query: SELECT * FROM Track WHERE GenreId = (SELECT GenreId FROM Genre WHERE Name = 'Rock');\n", +      "\n", +      "User input: how many artists are there?\n", +      "SQL query: \n" +     ] +    } +   ], +   "source": [ +    "print(prompt.format(input=\"how many artists are there?\", top_k=3, table_info=\"foo\"))" +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": 52, +   "metadata": {}, +   "outputs": [ +    { +     "data": { +      "text/plain": [ +       "'SELECT COUNT(*) FROM Artist;'" +      ] +     }, +     "execution_count": 52, +     "metadata": {}, +     "output_type": "execute_result" +    } +   ], +   "source": [ +    "chain = create_sql_query_chain(llm, db, prompt)\n", +    "chain.invoke({\"question\": \"how many artists are there?\"})" +   ] +  } + ], + "metadata": { +  "kernelspec": { +   "display_name": "poetry-venv", +   "language": "python", +   "name": "poetry-venv" +  }, +  "language_info": { +   "codemirror_mode": { +    "name": "ipython", +    "version": 3 +   }, +   "file_extension": ".py", +   "mimetype": "text/x-python", +   "name": "python", +   "nbconvert_exporter": "python", +   "pygments_lexer": "ipython3", +   "version": "3.9.1" +  } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/docs/docs/use_cases/sql/query_checking.ipynb b/docs/docs/use_cases/sql/query_checking.ipynb new file mode 100644 index 0000000000000..fcb5d44a2ba30 --- /dev/null +++ b/docs/docs/use_cases/sql/query_checking.ipynb @@ -0,0 +1,389 @@ +{ + "cells": [ +  { +   "cell_type": "raw", +   "id": "494149c1-9a1a-4b75-8982-6bb19cc5e14e", +   "metadata": {}, +   "source": [ +    "---\n", +    "sidebar_position: 3\n", +    "---" +   ] +  }, +  { +   "cell_type": "markdown", +   "id": "4da7ae91-4973-4e97-a570-fa24024ec65d", +   "metadata": {}, +   "source": [ +    "# Query validation\n", +    "\n", +    "Perhaps the most error-prone part of any SQL chain or agent is writing valid and safe SQL queries. In this guide we'll go over some strategies for validating our queries and handling invalid queries.\n", +    "\n", +    "## Setup\n", +    "\n", +    "First, get required packages and set environment variables:" +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": null, +   "id": "5d40d5bc-3647-4b5d-808a-db470d40fe7a", +   "metadata": {}, +   "outputs": [], +   "source": [ +    "%pip install --upgrade --quiet langchain langchain-community langchain-openai" +   ] +  }, +  { +   "cell_type": "markdown", +   "id": "c998536a-b1ff-46e7-ac51-dc6deb55d22b", +   "metadata": {}, +   "source": [ +    "We default to OpenAI models in this guide, but you can swap them out for the model provider of your choice." +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": null, +   "id": "71f46270-e1c6-45b4-b36e-ea2e9f860eba", +   "metadata": {}, +   "outputs": [], +   "source": [ +    "import getpass\n", +    "import os\n", +    "\n", +    "os.environ[\"OPENAI_API_KEY\"] = getpass.getpass()\n", +    "\n", +    "# Uncomment the below to use LangSmith. 
Not required.\n", +    "# os.environ[\"LANGCHAIN_API_KEY\"] = getpass.getpass()\n", +    "# os.environ[\"LANGCHAIN_TRACING_V2\"] = \"true\"" +   ] +  }, +  { +   "cell_type": "markdown", +   "id": "a0a2151b-cecf-4559-92a1-ca48824fed18", +   "metadata": {}, +   "source": [ +    "The below example will use a SQLite connection with Chinook database. Follow [these installation steps](https://database.guide/2-sample-databases-sqlite/) to create `Chinook.db` in the same directory as this notebook:\n", +    "\n", +    "* Save [this file](https://raw.githubusercontent.com/lerocha/chinook-database/master/ChinookDatabase/DataSources/Chinook_Sqlite.sql) as `Chinook_Sqlite.sql`\n", +    "* Run `sqlite3 Chinook.db`\n", +    "* Run `.read Chinook_Sqlite.sql`\n", +    "* Test `SELECT * FROM Artist LIMIT 10;`\n", +    "\n", +    "Now, `Chinook.db` is in our directory and we can interface with it using the SQLAlchemy-driven `SQLDatabase` class:" +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": 4, +   "id": "8cedc936-5268-4bfa-b838-bdcc1ee9573c", +   "metadata": {}, +   "outputs": [ +    { +     "name": "stdout", +     "output_type": "stream", +     "text": [ +      "sqlite\n", +      "['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']\n" +     ] +    }, +    { +     "data": { +      "text/plain": [ +       "\"[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]\"" +      ] +     }, +     "execution_count": 4, +     "metadata": {}, +     "output_type": "execute_result" +    } +   ], +   "source": [ +    "from langchain_community.utilities import SQLDatabase\n", +    "\n", +    "db = SQLDatabase.from_uri(\"sqlite:///Chinook.db\")\n", +    "print(db.dialect)\n", +    "print(db.get_usable_table_names())\n", +    "db.run(\"SELECT * FROM Artist LIMIT 10;\")" +   ] +  }, +  { +   "cell_type": "markdown", +   "id": "2d203315-fab7-4621-80da-41e9bf82d803", +   "metadata": {}, +   "source": [ +    "## Query checker\n", +    "\n", +    "Perhaps the simplest strategy is to ask the model itself to check the original query for common mistakes. Suppose we have the following SQL query chain:" +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": null, +   "id": "ec66bb76-b1ad-48ad-a7d4-b518e9421b86", +   "metadata": {}, +   "outputs": [], +   "source": [ +    "from langchain.chains import create_sql_query_chain\n", +    "from langchain_openai import ChatOpenAI\n", +    "\n", +    "llm = ChatOpenAI(model=\"gpt-3.5-turbo\", temperature=0)\n", +    "chain = create_sql_query_chain(llm, db)" +   ] +  }, +  { +   "cell_type": "markdown", +   "id": "da01023d-cc05-43e3-a38d-ed9d56d3ad15", +   "metadata": {}, +   "source": [ +    "And we want to validate its outputs. 
We can do so by extending the chain with a second prompt and model call:" +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": 25, +   "id": "16686750-d8ee-4c60-8d67-b28281cb6164", +   "metadata": {}, +   "outputs": [], +   "source": [ +    "from langchain_core.output_parsers import StrOutputParser\n", +    "from langchain_core.prompts import ChatPromptTemplate\n", +    "\n", +    "system = \"\"\"Double check the user's {dialect} query for common mistakes, including:\n", +    "- Using NOT IN with NULL values\n", +    "- Using UNION when UNION ALL should have been used\n", +    "- Using BETWEEN for exclusive ranges\n", +    "- Data type mismatch in predicates\n", +    "- Properly quoting identifiers\n", +    "- Using the correct number of arguments for functions\n", +    "- Casting to the correct data type\n", +    "- Using the proper columns for joins\n", +    "\n", +    "If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.\n", +    "\n", +    "Output the final SQL query only.\"\"\"\n", +    "prompt = ChatPromptTemplate.from_messages(\n", +    "    [(\"system\", system), (\"human\", \"{query}\")]\n", +    ").partial(dialect=db.dialect)\n", +    "validation_chain = prompt | llm | StrOutputParser()\n", +    "\n", +    "full_chain = {\"query\": chain} | validation_chain" +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": 37, +   "id": "3a910260-205d-4f4e-afc6-9477572dc947", +   "metadata": {}, +   "outputs": [ +    { +     "data": { +      "text/plain": [ +       "\"SELECT AVG(Invoice.Total) AS AverageInvoice\\nFROM Invoice\\nJOIN Customer ON Invoice.CustomerId = Customer.CustomerId\\nWHERE Customer.Country = 'USA'\\nAND Customer.Fax IS NULL\\nAND Invoice.InvoiceDate >= '2003-01-01'\\nAND Invoice.InvoiceDate < '2010-01-01'\"" +      ] +     }, +     "execution_count": 37, +     "metadata": {}, +     "output_type": "execute_result" +    } +   ], +   "source": [ +    "query = full_chain.invoke(\n", +    "    {\n", +    "        \"question\": \"What's the average Invoice from an American customer whose Fax is missing since 2003 but before 2010\"\n", +    "    }\n", +    ")\n", +    "query" +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": 38, +   "id": "d01d78b5-89a0-4c12-b743-707ebe64ba86", +   "metadata": {}, +   "outputs": [ +    { +     "data": { +      "text/plain": [ +       "'[(6.632999999999998,)]'" +      ] +     }, +     "execution_count": 38, +     "metadata": {}, +     "output_type": "execute_result" +    } +   ], +   "source": [ +    "db.run(query)" +   ] +  }, +  { +   "cell_type": "markdown", +   "id": "6e133526-26bd-49da-9cfa-7adc0e59fd72", +   "metadata": {}, +   "source": [ +    "The obvious downside of this approach is that we need to make two model calls instead of one to generate our query. To get around this we can try to perform the query generation and query check in a single model invocation:" +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": 45, +   "id": "7af0030a-549e-4e69-9298-3d0a038c2fdd", +   "metadata": {}, +   "outputs": [ +    { +     "name": "stdout", +     "output_type": "stream", +     "text": [ +      "================================\u001b[1m System Message \u001b[0m================================\n", +      "\n", +      "You are a \u001b[33;1m\u001b[1;3m{dialect}\u001b[0m expert. Given an input question, create a syntactically correct \u001b[33;1m\u001b[1;3m{dialect}\u001b[0m query to run.\n", +      "Unless the user specifies in the question a specific number of examples to obtain, query for at most \u001b[33;1m\u001b[1;3m{top_k}\u001b[0m results using the LIMIT clause as per \u001b[33;1m\u001b[1;3m{dialect}\u001b[0m. 
You can order the results to return the most informative data in the database.\n", +      "Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (\") to denote them as delimited identifiers.\n", +      "Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.\n", +      "Pay attention to use date('now') function to get the current date, if the question involves \"today\".\n", +      "\n", +      "Only use the following tables:\n", +      "\u001b[33;1m\u001b[1;3m{table_info}\u001b[0m\n", +      "\n", +      "Write an initial draft of the query. Then double check the \u001b[33;1m\u001b[1;3m{dialect}\u001b[0m query for common mistakes, including:\n", +      "- Using NOT IN with NULL values\n", +      "- Using UNION when UNION ALL should have been used\n", +      "- Using BETWEEN for exclusive ranges\n", +      "- Data type mismatch in predicates\n", +      "- Properly quoting identifiers\n", +      "- Using the correct number of arguments for functions\n", +      "- Casting to the correct data type\n", +      "- Using the proper columns for joins\n", +      "\n", +      "Use format:\n", +      "\n", +      "First draft: <<FIRST_DRAFT_QUERY>>\n", +      "Final answer: <<FINAL_ANSWER_QUERY>>\n", +      "\n", +      "\n", +      "================================\u001b[1m Human Message \u001b[0m=================================\n", +      "\n", +      "\u001b[33;1m\u001b[1;3m{input}\u001b[0m\n" +     ] +    } +   ], +   "source": [ +    "system = \"\"\"You are a {dialect} expert. Given an input question, create a syntactically correct {dialect} query to run.\n", +    "Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per {dialect}. You can order the results to return the most informative data in the database.\n", +    "Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (\") to denote them as delimited identifiers.\n", +    "Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.\n", +    "Pay attention to use date('now') function to get the current date, if the question involves \"today\".\n", +    "\n", +    "Only use the following tables:\n", +    "{table_info}\n", +    "\n", +    "Write an initial draft of the query. 
Then double check the {dialect} query for common mistakes, including:\n", +    "- Using NOT IN with NULL values\n", +    "- Using UNION when UNION ALL should have been used\n", +    "- Using BETWEEN for exclusive ranges\n", +    "- Data type mismatch in predicates\n", +    "- Properly quoting identifiers\n", +    "- Using the correct number of arguments for functions\n", +    "- Casting to the correct data type\n", +    "- Using the proper columns for joins\n", +    "\n", +    "Use format:\n", +    "\n", +    "First draft: <<FIRST_DRAFT_QUERY>>\n", +    "Final answer: <<FINAL_ANSWER_QUERY>>\n", +    "\"\"\"\n", +    "prompt = ChatPromptTemplate.from_messages([(\"system\", system), (\"human\", \"{input}\")]).partial(dialect=db.dialect)\n", +    "\n", +    "def parse_final_answer(output: str) -> str:\n", +    "    return output.split(\"Final answer: \")[1]\n", +    "    \n", +    "chain = create_sql_query_chain(llm, db, prompt=prompt) | parse_final_answer\n", +    "prompt.pretty_print()" +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": 46, +   "id": "806e27a2-e511-45ea-a4ed-8ce8fa6e1d58", +   "metadata": {}, +   "outputs": [ +    { +     "data": { +      "text/plain": [ +       "\"\\nSELECT AVG(i.Total) AS AverageInvoice\\nFROM Invoice i\\nJOIN Customer c ON i.CustomerId = c.CustomerId\\nWHERE c.Country = 'USA' AND c.Fax IS NULL AND i.InvoiceDate >= date('2003-01-01') AND i.InvoiceDate < date('2010-01-01')\"" +      ] +     }, +     "execution_count": 46, +     "metadata": {}, +     "output_type": "execute_result" +    } +   ], +   "source": [ +    "query = chain.invoke(\n", +    "    {\n", +    "        \"question\": \"What's the average Invoice from an American customer whose Fax is missing since 2003 but before 2010\"\n", +    "    }\n", +    ")\n", +    "query" +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": 47, +   "id": "70fff2fa-1f86-4f83-9fd2-e87a5234d329", +   "metadata": {}, +   "outputs": [ +    { +     "data": { +      "text/plain": [ +       "'[(6.632999999999998,)]'" +      ] +     }, +     "execution_count": 47, +     "metadata": {}, +     "output_type": "execute_result" +    } +   ], +   "source": [ +    "db.run(query)" +   ] +  }, +  { +   "cell_type": "markdown", +   "id": "fc8af115-7c23-421a-8fd7-29bf1b6687a4", +   "metadata": {}, +   "source": [ +    "## Human-in-the-loop\n", +    "\n", +    "In some cases our data is sensitive enough that we never want to execute a SQL query without a human approving it first. Head to the [Tool use: Human-in-the-loop](/docs/use_cases/tool_use/human_in_the_loop) page to learn how to add a human-in-the-loop to any tool, chain or agent.\n", +    "\n", +    "## Error handling\n", +    "\n", +    "At some point, the model will make a mistake and craft an invalid SQL query. Or an issue will arise with our database. Or the model API will go down. We'll want to add some error handling behavior to our chains and agents so that we fail gracefully in these situations, and perhaps even automatically recover. To learn about error handling with tools, head to the [Tool use: Error handling](/docs/use_cases/tool_use/tool_error_handling) page. A minimal sketch of both ideas follows." 
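+   ] +  }, +  { +   "cell_type": "markdown", +   "id": "a7c2e9d4-5b31-4f8e-9a6c-1d3b8f2e7c55", +   "metadata": {}, +   "source": [ +    "Here is a minimal sketch of a helper (our own, not a LangChain API) that asks a human for confirmation before executing a generated query and catches database errors instead of raising them:" +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": null, +   "id": "a7c2e9d4-5b31-4f8e-9a6c-1d3b8f2e7c56", +   "metadata": {}, +   "outputs": [], +   "source": [ +    "def run_query_safely(query: str) -> str:\n", +    "    # Hypothetical helper: gate execution on human approval and fail\n", +    "    # gracefully if the database rejects the query.\n", +    "    if input(f\"Execute this query? (y/n)\\n{query}\\n\").strip().lower() != \"y\":\n", +    "        return \"Execution cancelled by user.\"\n", +    "    try:\n", +    "        return db.run(query)\n", +    "    except Exception as e:\n", +    "        return f\"Query failed: {e}\"\n", +    "\n", +    "\n", +    "# run_query_safely(chain.invoke({\"question\": \"How many artists are there?\"}))"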
+   ] +  } + ], + "metadata": { +  "kernelspec": { +   "display_name": "Python 3 (ipykernel)", +   "language": "python", +   "name": "python3" +  }, +  "language_info": { +   "codemirror_mode": { +    "name": "ipython", +    "version": 3 +   }, +   "file_extension": ".py", +   "mimetype": "text/x-python", +   "name": "python", +   "nbconvert_exporter": "python", +   "pygments_lexer": "ipython3", +   "version": "3.9.1" +  } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/docs/use_cases/sql/quickstart.ipynb b/docs/docs/use_cases/sql/quickstart.ipynb new file mode 100644 index 0000000000000..490700a45197b --- /dev/null +++ b/docs/docs/use_cases/sql/quickstart.ipynb @@ -0,0 +1,603 @@ +{ + "cells": [ +  { +   "cell_type": "raw", +   "metadata": {}, +   "source": [ +    "---\n", +    "sidebar_position: 0\n", +    "---" +   ] +  }, +  { +   "cell_type": "markdown", +   "metadata": {}, +   "source": [ +    "# Quickstart\n", +    "\n", +    "In this guide we'll go over the basic ways to create a Q&A chain and agent over a SQL database. These systems will allow us to ask a question about the data in a SQL database and get back a natural language answer. The main difference between the two is that our agent can query the database in a loop as many times as it needs to answer the question.\n", +    "\n", +    "## ⚠️ Security note ⚠️\n", +    "\n", +    "Building Q&A systems of SQL databases requires executing model-generated SQL queries. There are inherent risks in doing this. Make sure that your database connection permissions are always scoped as narrowly as possible for your chain/agent's needs. This will mitigate though not eliminate the risks of building a model-driven system. For more on general security best practices, [see here](/docs/security).\n", +    "\n", +    "\n", +    "## Architecture\n", +    "\n", +    "At a high level, the steps of any SQL chain and agent are:\n", +    "\n", +    "1. **Convert question to SQL query**: Model converts user input to a SQL query.\n", +    "2. **Execute SQL query**: Execute the SQL query.\n", +    "3. **Answer the question**: Model responds to user input using the query results.\n", +    "\n", +    "\n", +    "![sql_usecase.png](../../../static/img/sql_usecase.png)\n", +    "\n", +    "## Setup\n", +    "\n", +    "First, get required packages and set environment variables:" +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": 2, +   "metadata": {}, +   "outputs": [], +   "source": [ +    "%pip install --upgrade --quiet langchain langchain-community langchain-openai" +   ] +  }, +  { +   "cell_type": "markdown", +   "metadata": {}, +   "source": [ +    "We default to OpenAI models in this guide." +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": null, +   "metadata": {}, +   "outputs": [], +   "source": [ +    "import getpass\n", +    "import os\n", +    "\n", +    "os.environ[\"OPENAI_API_KEY\"] = getpass.getpass()\n", +    "\n", +    "# Uncomment the below to use LangSmith. Not required.\n", +    "# os.environ[\"LANGCHAIN_API_KEY\"] = getpass.getpass()\n", +    "# os.environ[\"LANGCHAIN_TRACING_V2\"] = \"true\"" +   ] +  }, +  { +   "cell_type": "markdown", +   "metadata": {}, +   "source": [ +    "The below example will use a SQLite connection with Chinook database. 
Follow [these installation steps](https://database.guide/2-sample-databases-sqlite/) to create `Chinook.db` in the same directory as this notebook:\n", +    "\n", +    "* Save [this file](https://raw.githubusercontent.com/lerocha/chinook-database/master/ChinookDatabase/DataSources/Chinook_Sqlite.sql) as `Chinook_Sqlite.sql`\n", +    "* Run `sqlite3 Chinook.db`\n", +    "* Run `.read Chinook_Sqlite.sql`\n", +    "* Test `SELECT * FROM Artist LIMIT 10;`\n", +    "\n", +    "Now, `Chinook.db` is in our directory and we can interface with it using the SQLAlchemy-driven `SQLDatabase` class:" +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": 2, +   "metadata": {}, +   "outputs": [ +    { +     "name": "stdout", +     "output_type": "stream", +     "text": [ +      "sqlite\n", +      "['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']\n" +     ] +    }, +    { +     "data": { +      "text/plain": [ +       "\"[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]\"" +      ] +     }, +     "execution_count": 2, +     "metadata": {}, +     "output_type": "execute_result" +    } +   ], +   "source": [ +    "from langchain_community.utilities import SQLDatabase\n", +    "\n", +    "db = SQLDatabase.from_uri(\"sqlite:///Chinook.db\")\n", +    "print(db.dialect)\n", +    "print(db.get_usable_table_names())\n", +    "db.run(\"SELECT * FROM Artist LIMIT 10;\")" +   ] +  }, +  { +   "cell_type": "markdown", +   "metadata": {}, +   "source": [ +    "Great! We've got a SQL database that we can query. Now let's try hooking it up to an LLM.\n", +    "\n", +    "## Chain\n", +    "\n", +    "Let's create a simple chain that takes a question, turns it into a SQL query, executes the query, and uses the result to answer the original question.\n", +    "\n", +    "### Convert question to SQL query\n", +    "\n", +    "The first step in a SQL chain or agent is to take the user input and convert it to a SQL query. LangChain comes with a built-in chain for this: [create_sql_query_chain](https://api.python.langchain.com/en/latest/chains/langchain.chains.sql_database.query.create_sql_query_chain.html)" +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": 3, +   "metadata": {}, +   "outputs": [ +    { +     "data": { +      "text/plain": [ +       "'SELECT COUNT(*) FROM Employee'" +      ] +     }, +     "execution_count": 3, +     "metadata": {}, +     "output_type": "execute_result" +    } +   ], +   "source": [ +    "from langchain.chains import create_sql_query_chain\n", +    "from langchain_openai import ChatOpenAI\n", +    "\n", +    "llm = ChatOpenAI(model=\"gpt-3.5-turbo\", temperature=0)\n", +    "chain = create_sql_query_chain(llm, db)\n", +    "response = chain.invoke({\"question\": \"How many employees are there\"})\n", +    "response" +   ] +  }, +  { +   "cell_type": "markdown", +   "metadata": {}, +   "source": [ +    "We can execute the query to make sure it's valid:" +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": 3, +   "metadata": {}, +   "outputs": [ +    { +     "data": { +      "text/plain": [ +       "'[(8,)]'" +      ] +     }, +     "execution_count": 3, +     "metadata": {}, +     "output_type": "execute_result" +    } +   ], +   "source": [ +    "db.run(response)" +   ] +  }, +  { +   "cell_type": "markdown", +   "metadata": {}, +   "source": [ +    "We can look at the [LangSmith trace](https://smith.langchain.com/public/c8fa52ea-be46-4829-bde2-52894970b830/r) to get a better understanding of what this chain is doing. We can also inspect the chain directly for its prompts. 
Looking at the prompt (below), we can see that it is:\n", +    "\n", +    "* Dialect-specific. In this case it references SQLite explicitly.\n", +    "* Has definitions for all the available tables.\n", +    "* Has three example rows for each table.\n", +    "\n", +    "This technique is inspired by papers like [this](https://arxiv.org/pdf/2204.00498.pdf), which suggest that showing example rows and being explicit about tables improves performance." +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": 15, +   "metadata": {}, +   "outputs": [ +    { +     "name": "stdout", +     "output_type": "stream", +     "text": [ +      "You are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer to the input question.\n", +      "Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.\n", +      "Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (\") to denote them as delimited identifiers.\n", +      "Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.\n", +      "Pay attention to use date('now') function to get the current date, if the question involves \"today\".\n", +      "\n", +      "Use the following format:\n", +      "\n", +      "Question: Question here\n", +      "SQLQuery: SQL Query to run\n", +      "SQLResult: Result of the SQLQuery\n", +      "Answer: Final answer here\n", +      "\n", +      "Only use the following tables:\n", +      "\u001b[33;1m\u001b[1;3m{table_info}\u001b[0m\n", +      "\n", +      "Question: \u001b[33;1m\u001b[1;3m{input}\u001b[0m\n" +     ] +    } +   ], +   "source": [ +    "chain.get_prompts()[0].pretty_print()" +   ] +  }, +  { +   "cell_type": "markdown", +   "metadata": {}, +   "source": [ +    "### Execute SQL query\n", +    "\n", +    "Now that we've generated a SQL query, we'll want to execute it. **This is the most dangerous part of creating a SQL chain.** Consider carefully if it is OK to run automated queries over your data. Minimize the database connection permissions as much as possible. Consider adding a human approval step to your chains before query execution (see below).\n", +    "\n", +    "We can use the `QuerySQLDataBaseTool` to easily add query execution to our chain:" +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": 8, +   "metadata": {}, +   "outputs": [ +    { +     "data": { +      "text/plain": [ +       "'[(8,)]'" +      ] +     }, +     "execution_count": 8, +     "metadata": {}, +     "output_type": "execute_result" +    } +   ], +   "source": [ +    "from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool\n", +    "\n", +    "execute_query = QuerySQLDataBaseTool(db=db)\n", +    "write_query = create_sql_query_chain(llm, db)\n", +    "chain = write_query | execute_query\n", +    "chain.invoke({\"question\": \"How many employees are there\"})" +   ] +  }, +  { +   "cell_type": "markdown", +   "metadata": {}, +   "source": [ +    "### Answer the question\n", +    "\n", +    "Now that we've got a way to automatically generate and execute queries, we just need to combine the original question and SQL query result to generate a final answer. 
We can do this by passing question and result to the LLM once more:" +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": 12, +   "metadata": {}, +   "outputs": [ +    { +     "data": { +      "text/plain": [ +       "'There are 8 employees.'" +      ] +     }, +     "execution_count": 12, +     "metadata": {}, +     "output_type": "execute_result" +    } +   ], +   "source": [ +    "from operator import itemgetter\n", +    "\n", +    "from langchain_core.output_parsers import StrOutputParser\n", +    "from langchain_core.prompts import PromptTemplate\n", +    "from langchain_core.runnables import RunnablePassthrough\n", +    "\n", +    "answer_prompt = PromptTemplate.from_template(\n", +    "    \"\"\"Given the following user question, corresponding SQL query, and SQL result, answer the user question.\n", +    "\n", +    "Question: {question}\n", +    "SQL Query: {query}\n", +    "SQL Result: {result}\n", +    "Answer: \"\"\"\n", +    ")\n", +    "\n", +    "answer = answer_prompt | llm | StrOutputParser()\n", +    "chain = (\n", +    "    RunnablePassthrough.assign(query=write_query).assign(\n", +    "        result=itemgetter(\"query\") | execute_query\n", +    "    )\n", +    "    | answer\n", +    ")\n", +    "\n", +    "chain.invoke({\"question\": \"How many employees are there\"})" +   ] +  }, +  { +   "cell_type": "markdown", +   "metadata": {}, +   "source": [ +    "### Next steps\n", +    "\n", +    "For more complex query-generation, we may want to create few-shot prompts or add query-checking steps. For advanced techniques like this and more, check out:\n", +    "\n", +    "* [Prompting strategies](/docs/use_cases/sql/prompting): Advanced prompt engineering techniques.\n", +    "* [Query checking](/docs/use_cases/sql/query_checking): Add query validation and error handling.\n", +    "* [Large databases](/docs/use_cases/sql/large_db): Techniques for working with large databases." +   ] +  }, +  { +   "cell_type": "markdown", +   "metadata": {}, +   "source": [ +    "## Agents\n", +    "\n", +    "LangChain has an SQL Agent which provides a more flexible way of interacting with SQL databases. The main advantages of using the SQL Agent are:\n", +    "\n", +    "- It can answer questions based on the databases' schema as well as on the databases' content (like describing a specific table).\n", +    "- It can recover from errors by running a generated query, catching the traceback and regenerating it correctly.\n", +    "- It can answer questions that require multiple dependent queries.\n", +    "\n", +    "To initialize the agent, we use the `create_sql_agent` function. This agent uses the `SQLDatabaseToolkit`, which contains tools to: \n", +    "\n", +    "* Create and execute queries\n", +    "* Check query syntax\n", +    "* Retrieve table descriptions\n", +    "* ... 
and more\n", + "\n", + "### Initializing agent" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "from langchain_community.agent_toolkits import create_sql_agent\n", + "\n", + "agent_executor = create_sql_agent(llm, db=db, agent_type=\"openai-tools\", verbose=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", + "\u001b[32;1m\u001b[1;3m\n", + "Invoking: `sql_db_list_tables` with `{}`\n", + "\n", + "\n", + "\u001b[0m\u001b[38;5;200m\u001b[1;3mAlbum, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track\u001b[0m\u001b[32;1m\u001b[1;3m\n", + "Invoking: `sql_db_schema` with `Invoice,Customer`\n", + "\n", + "\n", + "\u001b[0m\u001b[33;1m\u001b[1;3m\n", + "CREATE TABLE \"Customer\" (\n", + "\t\"CustomerId\" INTEGER NOT NULL, \n", + "\t\"FirstName\" NVARCHAR(40) NOT NULL, \n", + "\t\"LastName\" NVARCHAR(20) NOT NULL, \n", + "\t\"Company\" NVARCHAR(80), \n", + "\t\"Address\" NVARCHAR(70), \n", + "\t\"City\" NVARCHAR(40), \n", + "\t\"State\" NVARCHAR(40), \n", + "\t\"Country\" NVARCHAR(40), \n", + "\t\"PostalCode\" NVARCHAR(10), \n", + "\t\"Phone\" NVARCHAR(24), \n", + "\t\"Fax\" NVARCHAR(24), \n", + "\t\"Email\" NVARCHAR(60) NOT NULL, \n", + "\t\"SupportRepId\" INTEGER, \n", + "\tPRIMARY KEY (\"CustomerId\"), \n", + "\tFOREIGN KEY(\"SupportRepId\") REFERENCES \"Employee\" (\"EmployeeId\")\n", + ")\n", + "\n", + "/*\n", + "3 rows from Customer table:\n", + "CustomerId\tFirstName\tLastName\tCompany\tAddress\tCity\tState\tCountry\tPostalCode\tPhone\tFax\tEmail\tSupportRepId\n", + "1\tLuís\tGonçalves\tEmbraer - Empresa Brasileira de Aeronáutica S.A.\tAv. 
Brigadeiro Faria Lima, 2170\tSão José dos Campos\tSP\tBrazil\t12227-000\t+55 (12) 3923-5555\t+55 (12) 3923-5566\tluisg@embraer.com.br\t3\n", + "2\tLeonie\tKöhler\tNone\tTheodor-Heuss-Straße 34\tStuttgart\tNone\tGermany\t70174\t+49 0711 2842222\tNone\tleonekohler@surfeu.de\t5\n", + "3\tFrançois\tTremblay\tNone\t1498 rue Bélanger\tMontréal\tQC\tCanada\tH2G 1A7\t+1 (514) 721-4711\tNone\tftremblay@gmail.com\t3\n", + "*/\n", + "\n", + "\n", + "CREATE TABLE \"Invoice\" (\n", + "\t\"InvoiceId\" INTEGER NOT NULL, \n", + "\t\"CustomerId\" INTEGER NOT NULL, \n", + "\t\"InvoiceDate\" DATETIME NOT NULL, \n", + "\t\"BillingAddress\" NVARCHAR(70), \n", + "\t\"BillingCity\" NVARCHAR(40), \n", + "\t\"BillingState\" NVARCHAR(40), \n", + "\t\"BillingCountry\" NVARCHAR(40), \n", + "\t\"BillingPostalCode\" NVARCHAR(10), \n", + "\t\"Total\" NUMERIC(10, 2) NOT NULL, \n", + "\tPRIMARY KEY (\"InvoiceId\"), \n", + "\tFOREIGN KEY(\"CustomerId\") REFERENCES \"Customer\" (\"CustomerId\")\n", + ")\n", + "\n", + "/*\n", + "3 rows from Invoice table:\n", + "InvoiceId\tCustomerId\tInvoiceDate\tBillingAddress\tBillingCity\tBillingState\tBillingCountry\tBillingPostalCode\tTotal\n", + "1\t2\t2009-01-01 00:00:00\tTheodor-Heuss-Straße 34\tStuttgart\tNone\tGermany\t70174\t1.98\n", + "2\t4\t2009-01-02 00:00:00\tUllevålsveien 14\tOslo\tNone\tNorway\t0171\t3.96\n", + "3\t8\t2009-01-03 00:00:00\tGrétrystraat 63\tBrussels\tNone\tBelgium\t1000\t5.94\n", + "*/\u001b[0m\u001b[32;1m\u001b[1;3m\n", + "Invoking: `sql_db_query` with `SELECT c.Country, SUM(i.Total) AS TotalSales FROM Invoice i JOIN Customer c ON i.CustomerId = c.CustomerId GROUP BY c.Country ORDER BY TotalSales DESC LIMIT 10;`\n", + "responded: To list the total sales per country, I can query the \"Invoice\" and \"Customer\" tables. I will join these tables on the \"CustomerId\" column and group the results by the \"BillingCountry\" column. Then, I will calculate the sum of the \"Total\" column to get the total sales per country. Finally, I will order the results in descending order of the total sales.\n", + "\n", + "Here is the SQL query:\n", + "\n", + "```sql\n", + "SELECT c.Country, SUM(i.Total) AS TotalSales\n", + "FROM Invoice i\n", + "JOIN Customer c ON i.CustomerId = c.CustomerId\n", + "GROUP BY c.Country\n", + "ORDER BY TotalSales DESC\n", + "LIMIT 10;\n", + "```\n", + "\n", + "Now, I will execute this query to get the total sales per country.\n", + "\n", + "\u001b[0m\u001b[36;1m\u001b[1;3m[('USA', 523.0600000000003), ('Canada', 303.9599999999999), ('France', 195.09999999999994), ('Brazil', 190.09999999999997), ('Germany', 156.48), ('United Kingdom', 112.85999999999999), ('Czech Republic', 90.24000000000001), ('Portugal', 77.23999999999998), ('India', 75.25999999999999), ('Chile', 46.62)]\u001b[0m\u001b[32;1m\u001b[1;3mThe total sales per country are as follows:\n", + "\n", + "1. USA: $523.06\n", + "2. Canada: $303.96\n", + "3. France: $195.10\n", + "4. Brazil: $190.10\n", + "5. Germany: $156.48\n", + "6. United Kingdom: $112.86\n", + "7. Czech Republic: $90.24\n", + "8. Portugal: $77.24\n", + "9. India: $75.26\n", + "10. Chile: $46.62\n", + "\n", + "To answer the second question, the country whose customers spent the most is the USA, with a total sales of $523.06.\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "{'input': \"List the total sales per country. Which country's customers spent the most?\",\n", + " 'output': 'The total sales per country are as follows:\\n\\n1. USA: $523.06\\n2. 
Canada: $303.96\\n3. France: $195.10\\n4. Brazil: $190.10\\n5. Germany: $156.48\\n6. United Kingdom: $112.86\\n7. Czech Republic: $90.24\\n8. Portugal: $77.24\\n9. India: $75.26\\n10. Chile: $46.62\\n\\nTo answer the second question, the country whose customers spent the most is the USA, with a total sales of $523.06.'}" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "agent_executor.invoke(\n", + " {\n", + " \"input\": \"List the total sales per country. Which country's customers spent the most?\"\n", + " }\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", + "\u001b[32;1m\u001b[1;3m\n", + "Invoking: `sql_db_list_tables` with `{}`\n", + "\n", + "\n", + "\u001b[0m\u001b[38;5;200m\u001b[1;3mAlbum, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track\u001b[0m\u001b[32;1m\u001b[1;3m\n", + "Invoking: `sql_db_schema` with `PlaylistTrack`\n", + "\n", + "\n", + "\u001b[0m\u001b[33;1m\u001b[1;3m\n", + "CREATE TABLE \"PlaylistTrack\" (\n", + "\t\"PlaylistId\" INTEGER NOT NULL, \n", + "\t\"TrackId\" INTEGER NOT NULL, \n", + "\tPRIMARY KEY (\"PlaylistId\", \"TrackId\"), \n", + "\tFOREIGN KEY(\"TrackId\") REFERENCES \"Track\" (\"TrackId\"), \n", + "\tFOREIGN KEY(\"PlaylistId\") REFERENCES \"Playlist\" (\"PlaylistId\")\n", + ")\n", + "\n", + "/*\n", + "3 rows from PlaylistTrack table:\n", + "PlaylistId\tTrackId\n", + "1\t3402\n", + "1\t3389\n", + "1\t3390\n", + "*/\u001b[0m\u001b[32;1m\u001b[1;3mThe `PlaylistTrack` table has two columns: `PlaylistId` and `TrackId`. It is a junction table that represents the many-to-many relationship between playlists and tracks. \n", + "\n", + "Here is the schema of the `PlaylistTrack` table:\n", + "\n", + "```\n", + "CREATE TABLE \"PlaylistTrack\" (\n", + "\t\"PlaylistId\" INTEGER NOT NULL, \n", + "\t\"TrackId\" INTEGER NOT NULL, \n", + "\tPRIMARY KEY (\"PlaylistId\", \"TrackId\"), \n", + "\tFOREIGN KEY(\"TrackId\") REFERENCES \"Track\" (\"TrackId\"), \n", + "\tFOREIGN KEY(\"PlaylistId\") REFERENCES \"Playlist\" (\"PlaylistId\")\n", + ")\n", + "```\n", + "\n", + "The `PlaylistId` column is a foreign key referencing the `PlaylistId` column in the `Playlist` table. The `TrackId` column is a foreign key referencing the `TrackId` column in the `Track` table.\n", + "\n", + "Here are three sample rows from the `PlaylistTrack` table:\n", + "\n", + "```\n", + "PlaylistId TrackId\n", + "1 3402\n", + "1 3389\n", + "1 3390\n", + "```\n", + "\n", + "Please let me know if there is anything else I can help with.\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "{'input': 'Describe the playlisttrack table',\n", + " 'output': 'The `PlaylistTrack` table has two columns: `PlaylistId` and `TrackId`. It is a junction table that represents the many-to-many relationship between playlists and tracks. 
\\n\\nHere is the schema of the `PlaylistTrack` table:\\n\\n```\\nCREATE TABLE \"PlaylistTrack\" (\\n\\t\"PlaylistId\" INTEGER NOT NULL, \\n\\t\"TrackId\" INTEGER NOT NULL, \\n\\tPRIMARY KEY (\"PlaylistId\", \"TrackId\"), \\n\\tFOREIGN KEY(\"TrackId\") REFERENCES \"Track\" (\"TrackId\"), \\n\\tFOREIGN KEY(\"PlaylistId\") REFERENCES \"Playlist\" (\"PlaylistId\")\\n)\\n```\\n\\nThe `PlaylistId` column is a foreign key referencing the `PlaylistId` column in the `Playlist` table. The `TrackId` column is a foreign key referencing the `TrackId` column in the `Track` table.\\n\\nHere are three sample rows from the `PlaylistTrack` table:\\n\\n```\\nPlaylistId TrackId\\n1 3402\\n1 3389\\n1 3390\\n```\\n\\nPlease let me know if there is anything else I can help with.'}" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "agent_executor.invoke({\"input\": \"Describe the playlisttrack table\"})" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Next steps\n", + "\n", + "For more on how to use and customize agents head to the [Agents](/docs/use_cases/sql/agents) page." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "poetry-venv", + "language": "python", + "name": "poetry-venv" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.1" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} From 21d423d9a8fac6031ca956f7b039814dfb7dbb93 Mon Sep 17 00:00:00 2001 From: Piotr Mardziel Date: Mon, 22 Jan 2024 11:34:13 -0800 Subject: [PATCH 03/13] core[patch]: preserve inspect.iscoroutinefunction with @deprecated decorator (#16295) Adjusted `deprecate` decorator to make sure decorated async functions are still recognized as "coroutinefunction" by `inspect`. Before change, functions such as `LLMChain.acall` which are decorated as deprecated are not recognized as coroutine functions. After the change, they are recognized: ```python import inspect from langchain import LLMChain # Is false before change but true after. inspect.iscoroutinefunction(LLMChain.acall) ``` --- .../tests/unit_tests/_api/test_deprecation.py | 49 +++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/libs/core/tests/unit_tests/_api/test_deprecation.py b/libs/core/tests/unit_tests/_api/test_deprecation.py index fc05c002575da..dbfd2554aabb5 100644 --- a/libs/core/tests/unit_tests/_api/test_deprecation.py +++ b/libs/core/tests/unit_tests/_api/test_deprecation.py @@ -155,6 +155,30 @@ async def test_deprecated_async_function() -> None: assert inspect.iscoroutinefunction(deprecated_async_function) + assert not inspect.iscoroutinefunction(deprecated_function) + + +@pytest.mark.asyncio +async def test_deprecated_async_function() -> None: + """Test deprecated async function.""" + with warnings.catch_warnings(record=True) as warning_list: + warnings.simplefilter("always") + assert ( + await deprecated_async_function() == "This is a deprecated async function." 
+ )
+ assert len(warning_list) == 1
+ warning = warning_list[0].message
+ assert str(warning) == (
+ "The function `deprecated_async_function` was deprecated "
+ "in LangChain 2.0.0 and will be removed in 3.0.0"
+ )
+
+ doc = deprecated_function.__doc__
+ assert isinstance(doc, str)
+ assert doc.startswith("[*Deprecated*] original doc")
+
+ assert inspect.iscoroutinefunction(deprecated_async_function)
+

 def test_deprecated_method() -> None:
 """Test deprecated method."""
@@ -198,6 +222,31 @@ async def test_deprecated_async_method() -> None:

 assert inspect.iscoroutinefunction(obj.deprecated_async_method)

+ assert not inspect.iscoroutinefunction(obj.deprecated_method)
+
+
+@pytest.mark.asyncio
+async def test_deprecated_async_method() -> None:
+ """Test deprecated async method."""
+ with warnings.catch_warnings(record=True) as warning_list:
+ warnings.simplefilter("always")
+ obj = ClassWithDeprecatedMethods()
+ assert (
+ await obj.deprecated_async_method() == "This is a deprecated async method."
+ )
+ assert len(warning_list) == 1
+ warning = warning_list[0].message
+ assert str(warning) == (
+ "The function `deprecated_async_method` was deprecated in "
+ "LangChain 2.0.0 and will be removed in 3.0.0"
+ )
+
+ doc = obj.deprecated_method.__doc__
+ assert isinstance(doc, str)
+ assert doc.startswith("[*Deprecated*] original doc")
+
+ assert inspect.iscoroutinefunction(obj.deprecated_async_method)
+

 def test_deprecated_classmethod() -> None:
 """Test deprecated classmethod."""

From c68c8183c34d9bd4915c734faf1214b44ba74972 Mon Sep 17 00:00:00 2001
From: ChengZi
Date: Tue, 23 Jan 2024 06:25:26 +0800
Subject: [PATCH 04/13] docs: add milvus multitenancy doc (#16177)

- **Description:** add milvus multitenancy doc; it is an example for
this [pr](https://github.com/langchain-ai/langchain/pull/15740).
- **Issue:** No,
- **Dependencies:** No,
- **Twitter handle:** No

Signed-off-by: ChengZi
---
 .../integrations/vectorstores/milvus.ipynb | 114 ++++++++++++++++++
 1 file changed, 114 insertions(+)

diff --git a/docs/docs/integrations/vectorstores/milvus.ipynb b/docs/docs/integrations/vectorstores/milvus.ipynb
index 24cb43b436bdb..bfd7fac629321 100644
--- a/docs/docs/integrations/vectorstores/milvus.ipynb
+++ b/docs/docs/integrations/vectorstores/milvus.ipynb
@@ -367,6 +367,120 @@
 "]\n",
 "upserted_pks = vector_db.upsert(pks, new_docs)"
 ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### Per-User Retrieval\n",
+ "\n",
+ "When building a retrieval app, you often have to build it with multiple users in mind. This means that you may be storing data not just for one user, but for many different users, and they should not be able to see each other’s data.\n",
+ "\n",
+ "Milvus recommends using [partition_key](https://milvus.io/docs/multi_tenancy.md#Partition-key-based-multi-tenancy) to implement multi-tenancy; here is an example.\n",
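+ "\n",
+ "As a rough sketch of the pattern before the runnable cells below (the `namespace` field and the tenant names are purely illustrative): each document carries a tenant identifier in its metadata, that field is declared as the partition key, and every search is scoped by a boolean filter expression on it:\n",
+ "\n",
+ "```python\n",
+ "# Illustrative sketch only -- assumes \"namespace\" is the metadata field used as the partition key.\n",
+ "one_tenant = {\"expr\": 'namespace == \"ankush\"'}  # scope a search to a single tenant\n",
+ "many_tenants = {\"expr\": 'namespace in [\"ankush\", \"harrison\"]'}  # or to several tenants\n",
+ "# Either dict can be passed as search_kwargs when building a retriever, as shown below.\n",
+ "```"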
+ ],
+ "metadata": {
+ "collapsed": false,
+ "pycharm": {
+ "name": "#%% md\n"
+ }
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "outputs": [],
+ "source": [
+ "from langchain_core.documents import Document\n",
+ "\n",
+ "docs = [\n",
+ " Document(page_content=\"i worked at kensho\", metadata={\"namespace\": \"harrison\"}),\n",
+ " Document(page_content=\"i worked at facebook\", metadata={\"namespace\": \"ankush\"}),\n",
+ "]\n",
+ "vectorstore = Milvus.from_documents(\n",
+ " docs,\n",
+ " embeddings,\n",
+ " connection_args={\"host\": \"127.0.0.1\", \"port\": \"19530\"},\n",
+ " drop_old=True,\n",
+ " partition_key_field=\"namespace\", # Use the \"namespace\" field as the partition key\n",
+ ")"
+ ],
+ "metadata": {
+ "collapsed": false,
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "To conduct a search using the partition key, you should include either of the following in the boolean expression of the search request:\n",
+ "\n",
+ "`search_kwargs={\"expr\": '<partition_key> == \"xxxx\"'}`\n",
+ "\n",
+ "`search_kwargs={\"expr\": '<partition_key> in [\"xxx\", \"xxx\"]'}`\n",
+ "\n",
+ "Replace `<partition_key>` with the name of the field that is designated as the partition key.\n",
+ "\n",
+ "Milvus switches to a partition based on the specified partition key, filters entities according to the partition key, and searches among the filtered entities.\n"
+ ],
+ "metadata": {
+ "collapsed": false,
+ "pycharm": {
+ "name": "#%% md\n"
+ }
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "outputs": [
+ {
+ "data": {
+ "text/plain": "[Document(page_content='i worked at facebook', metadata={'namespace': 'ankush'})]"
+ },
+ "execution_count": 3,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# This will only get documents for Ankush\n",
+ "vectorstore.as_retriever(\n",
+ " search_kwargs={\"expr\": 'namespace == \"ankush\"'}\n",
+ ").get_relevant_documents(\"where did i work?\")"
+ ],
+ "metadata": {
+ "collapsed": false,
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "outputs": [
+ {
+ "data": {
+ "text/plain": "[Document(page_content='i worked at kensho', metadata={'namespace': 'harrison'})]"
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# This will only get documents for Harrison\n",
+ "vectorstore.as_retriever(\n",
+ " search_kwargs={\"expr\": 'namespace == \"harrison\"'}\n",
+ ").get_relevant_documents(\"where did i work?\")"
+ ],
+ "metadata": {
+ "collapsed": false,
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ }
 }
 ],
 "metadata": {

From 2b3efff71515dae809a6c84eebe2d7aa56bccd8b Mon Sep 17 00:00:00 2001
From: Philippe Prados
Date: Mon, 27 May 2024 14:50:53 +0200
Subject: [PATCH 05/13] Add sql docstore

---
 .../integrations/vectorstores/milvus.ipynb | 116 +--
 docs/docs/use_cases/sql/agents.ipynb | 815 ------------------
 docs/docs/use_cases/sql/index.ipynb | 68 --
 docs/docs/use_cases/sql/large_db.ipynb | 627 --------------
 docs/docs/use_cases/sql/prompting.ipynb | 789 -----------------
 docs/docs/use_cases/sql/query_checking.ipynb | 389 ---------
 docs/docs/use_cases/sql/quickstart.ipynb | 603 -------------
 .../langchain_community/storage/__init__.py | 5 +
 .../storage/sql_docstore.py | 236 -----
 .../tests/unit_tests/storage/test_imports.py | 1 +
 .../tests/unit_tests/storage/test_sql.py | 89 ++
 .../tests/unit_tests/test_dependencies.py | 1 +
 .../tests/unit_tests/_api/test_deprecation.py | 49 --
 13 files changed, 97 
insertions(+), 3691 deletions(-)
 delete mode 100644 docs/docs/use_cases/sql/agents.ipynb
 delete mode 100644 docs/docs/use_cases/sql/index.ipynb
 delete mode 100644 docs/docs/use_cases/sql/large_db.ipynb
 delete mode 100644 docs/docs/use_cases/sql/prompting.ipynb
 delete mode 100644 docs/docs/use_cases/sql/query_checking.ipynb
 delete mode 100644 docs/docs/use_cases/sql/quickstart.ipynb
 delete mode 100644 libs/community/langchain_community/storage/sql_docstore.py
 create mode 100644 libs/community/tests/unit_tests/storage/test_sql.py

diff --git a/docs/docs/integrations/vectorstores/milvus.ipynb b/docs/docs/integrations/vectorstores/milvus.ipynb
index bfd7fac629321..4c314dfa15f0b 100644
--- a/docs/docs/integrations/vectorstores/milvus.ipynb
+++ b/docs/docs/integrations/vectorstores/milvus.ipynb
@@ -367,120 +367,6 @@
 "]\n",
 "upserted_pks = vector_db.upsert(pks, new_docs)"
 ]
- },
- {
- "cell_type": "markdown",
- "source": [
- "### Per-User Retrieval\n",
- "\n",
- "When building a retrieval app, you often have to build it with multiple users in mind. This means that you may be storing data not just for one user, but for many different users, and they should not be able to see each other’s data.\n",
- "\n",
- "Milvus recommends using [partition_key](https://milvus.io/docs/multi_tenancy.md#Partition-key-based-multi-tenancy) to implement multi-tenancy; here is an example."
- ],
- "metadata": {
- "collapsed": false,
- "pycharm": {
- "name": "#%% md\n"
- }
- }
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "outputs": [],
- "source": [
- "from langchain_core.documents import Document\n",
- "\n",
- "docs = [\n",
- " Document(page_content=\"i worked at kensho\", metadata={\"namespace\": \"harrison\"}),\n",
- " Document(page_content=\"i worked at facebook\", metadata={\"namespace\": \"ankush\"}),\n",
- "]\n",
- "vectorstore = Milvus.from_documents(\n",
- " docs,\n",
- " embeddings,\n",
- " connection_args={\"host\": \"127.0.0.1\", \"port\": \"19530\"},\n",
- " drop_old=True,\n",
- " partition_key_field=\"namespace\", # Use the \"namespace\" field as the partition key\n",
- ")"
- ],
- "metadata": {
- "collapsed": false,
- "pycharm": {
- "name": "#%%\n"
- }
- }
- },
- {
- "cell_type": "markdown",
- "source": [
- "To conduct a search using the partition key, you should include either of the following in the boolean expression of the search request:\n",
- "\n",
- "`search_kwargs={\"expr\": '<partition_key> == \"xxxx\"'}`\n",
- "\n",
- "`search_kwargs={\"expr\": '<partition_key> in [\"xxx\", \"xxx\"]'}`\n",
- "\n",
- "Replace `<partition_key>` with the name of the field that is designated as the partition key.\n",
- "\n",
- "Milvus switches to a partition based on the specified partition key, filters entities according to the partition key, and searches among the filtered entities.\n"
- ],
- "metadata": {
- "collapsed": false,
- "pycharm": {
- "name": "#%% md\n"
- }
- }
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "outputs": [
- {
- "data": {
- "text/plain": "[Document(page_content='i worked at facebook', metadata={'namespace': 'ankush'})]"
- },
- "execution_count": 3,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "# This will only get documents for Ankush\n",
- "vectorstore.as_retriever(\n",
- " search_kwargs={\"expr\": 'namespace == \"ankush\"'}\n",
- ").get_relevant_documents(\"where did i work?\")"
- ],
- "metadata": {
- "collapsed": false,
- "pycharm": {
- "name": "#%%\n"
- }
- }
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "outputs": [
- {
- "data": {
- "text/plain": 
"[Document(page_content='i worked at kensho', metadata={'namespace': 'harrison'})]" - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# This will only get documents for Harrison\n", - "vectorstore.as_retriever(\n", - " search_kwargs={\"expr\": 'namespace == \"harrison\"'}\n", - ").get_relevant_documents(\"where did i work?\")" - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } - } } ], "metadata": { @@ -504,4 +390,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +} diff --git a/docs/docs/use_cases/sql/agents.ipynb b/docs/docs/use_cases/sql/agents.ipynb deleted file mode 100644 index aa6db0dd3f920..0000000000000 --- a/docs/docs/use_cases/sql/agents.ipynb +++ /dev/null @@ -1,815 +0,0 @@ -{ - "cells": [ - { - "cell_type": "raw", - "metadata": {}, - "source": [ - "---\n", - "sidebar_position: 1\n", - "---" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Agents\n", - "\n", - "LangChain has a SQL Agent which provides a more flexible way of interacting with SQL Databases than a chain. The main advantages of using the SQL Agent are:\n", - "\n", - "- It can answer questions based on the databases' schema as well as on the databases' content (like describing a specific table).\n", - "- It can recover from errors by running a generated query, catching the traceback and regenerating it correctly.\n", - "- It can query the database as many times as needed to answer the user question.\n", - "\n", - "To initialize the agent we'll use the [create_sql_agent](https://api.python.langchain.com/en/latest/agent_toolkits/langchain_community.agent_toolkits.sql.base.create_sql_agent.html) constructor. This agent uses the `SQLDatabaseToolkit` which contains tools to: \n", - "\n", - "* Create and execute queries\n", - "* Check query syntax\n", - "* Retrieve table descriptions\n", - "* ... and more\n", - "\n", - "## Setup\n", - "\n", - "First, get required packages and set environment variables:" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "%pip install --upgrade --quiet langchain langchain-community langchain-openai" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We default to OpenAI models in this guide, but you can swap them out for the model provider of your choice." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import getpass\n", - "import os\n", - "\n", - "os.environ[\"OPENAI_API_KEY\"] = getpass.getpass()\n", - "\n", - "# Uncomment the below to use LangSmith. Not required.\n", - "# os.environ[\"LANGCHAIN_API_KEY\"] = getpass.getpass()\n", - "# os.environ[\"LANGCHAIN_TRACING_V2\"] = \"true\"" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The below example will use a SQLite connection with Chinook database. 
Follow [these installation steps](https://database.guide/2-sample-databases-sqlite/) to create `Chinook.db` in the same directory as this notebook:\n", - "\n", - "* Save [this file](https://raw.githubusercontent.com/lerocha/chinook-database/master/ChinookDatabase/DataSources/Chinook_Sqlite.sql) as `Chinook_Sqlite.sql`\n", - "* Run `sqlite3 Chinook.db`\n", - "* Run `.read Chinook_Sqlite.sql`\n", - "* Test `SELECT * FROM Artist LIMIT 10;`\n", - "\n", - "Now, `Chinhook.db` is in our directory and we can interface with it using the SQLAlchemy-driven `SQLDatabase` class:" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "sqlite\n", - "['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']\n" - ] - }, - { - "data": { - "text/plain": [ - "\"[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]\"" - ] - }, - "execution_count": 1, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from langchain_community.utilities import SQLDatabase\n", - "\n", - "db = SQLDatabase.from_uri(\"sqlite:///Chinook.db\")\n", - "print(db.dialect)\n", - "print(db.get_usable_table_names())\n", - "db.run(\"SELECT * FROM Artist LIMIT 10;\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Agent\n", - "\n", - "We'll use an OpenAI chat model and an `\"openai-tools\"` agent, which will use OpenAI's function-calling API to drive the agent's tool selection and invocations." - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "metadata": {}, - "outputs": [], - "source": [ - "from langchain_community.agent_toolkits import create_sql_agent\n", - "from langchain_openai import ChatOpenAI\n", - "\n", - "llm = ChatOpenAI(model=\"gpt-3.5-turbo\", temperature=0)\n", - "agent_executor = create_sql_agent(llm, db=db, agent_type=\"openai-tools\", verbose=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "\n", - "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", - "\u001b[32;1m\u001b[1;3m\n", - "Invoking: `sql_db_list_tables` with `{}`\n", - "\n", - "\n", - "\u001b[0m\u001b[38;5;200m\u001b[1;3mAlbum, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track\u001b[0m\u001b[32;1m\u001b[1;3m\n", - "Invoking: `sql_db_schema` with `Invoice,Customer`\n", - "\n", - "\n", - "\u001b[0m\u001b[33;1m\u001b[1;3m\n", - "CREATE TABLE \"Customer\" (\n", - "\t\"CustomerId\" INTEGER NOT NULL, \n", - "\t\"FirstName\" NVARCHAR(40) NOT NULL, \n", - "\t\"LastName\" NVARCHAR(20) NOT NULL, \n", - "\t\"Company\" NVARCHAR(80), \n", - "\t\"Address\" NVARCHAR(70), \n", - "\t\"City\" NVARCHAR(40), \n", - "\t\"State\" NVARCHAR(40), \n", - "\t\"Country\" NVARCHAR(40), \n", - "\t\"PostalCode\" NVARCHAR(10), \n", - "\t\"Phone\" NVARCHAR(24), \n", - "\t\"Fax\" NVARCHAR(24), \n", - "\t\"Email\" NVARCHAR(60) NOT NULL, \n", - "\t\"SupportRepId\" INTEGER, \n", - "\tPRIMARY KEY (\"CustomerId\"), \n", - "\tFOREIGN KEY(\"SupportRepId\") REFERENCES \"Employee\" (\"EmployeeId\")\n", - ")\n", - "\n", - "/*\n", - "3 rows from Customer table:\n", - 
"CustomerId\tFirstName\tLastName\tCompany\tAddress\tCity\tState\tCountry\tPostalCode\tPhone\tFax\tEmail\tSupportRepId\n", - "1\tLuís\tGonçalves\tEmbraer - Empresa Brasileira de Aeronáutica S.A.\tAv. Brigadeiro Faria Lima, 2170\tSão José dos Campos\tSP\tBrazil\t12227-000\t+55 (12) 3923-5555\t+55 (12) 3923-5566\tluisg@embraer.com.br\t3\n", - "2\tLeonie\tKöhler\tNone\tTheodor-Heuss-Straße 34\tStuttgart\tNone\tGermany\t70174\t+49 0711 2842222\tNone\tleonekohler@surfeu.de\t5\n", - "3\tFrançois\tTremblay\tNone\t1498 rue Bélanger\tMontréal\tQC\tCanada\tH2G 1A7\t+1 (514) 721-4711\tNone\tftremblay@gmail.com\t3\n", - "*/\n", - "\n", - "\n", - "CREATE TABLE \"Invoice\" (\n", - "\t\"InvoiceId\" INTEGER NOT NULL, \n", - "\t\"CustomerId\" INTEGER NOT NULL, \n", - "\t\"InvoiceDate\" DATETIME NOT NULL, \n", - "\t\"BillingAddress\" NVARCHAR(70), \n", - "\t\"BillingCity\" NVARCHAR(40), \n", - "\t\"BillingState\" NVARCHAR(40), \n", - "\t\"BillingCountry\" NVARCHAR(40), \n", - "\t\"BillingPostalCode\" NVARCHAR(10), \n", - "\t\"Total\" NUMERIC(10, 2) NOT NULL, \n", - "\tPRIMARY KEY (\"InvoiceId\"), \n", - "\tFOREIGN KEY(\"CustomerId\") REFERENCES \"Customer\" (\"CustomerId\")\n", - ")\n", - "\n", - "/*\n", - "3 rows from Invoice table:\n", - "InvoiceId\tCustomerId\tInvoiceDate\tBillingAddress\tBillingCity\tBillingState\tBillingCountry\tBillingPostalCode\tTotal\n", - "1\t2\t2009-01-01 00:00:00\tTheodor-Heuss-Straße 34\tStuttgart\tNone\tGermany\t70174\t1.98\n", - "2\t4\t2009-01-02 00:00:00\tUllevålsveien 14\tOslo\tNone\tNorway\t0171\t3.96\n", - "3\t8\t2009-01-03 00:00:00\tGrétrystraat 63\tBrussels\tNone\tBelgium\t1000\t5.94\n", - "*/\u001b[0m\u001b[32;1m\u001b[1;3m\n", - "Invoking: `sql_db_query` with `SELECT c.Country, SUM(i.Total) AS TotalSales FROM Invoice i JOIN Customer c ON i.CustomerId = c.CustomerId GROUP BY c.Country ORDER BY TotalSales DESC`\n", - "responded: To list the total sales per country, I can query the \"Invoice\" and \"Customer\" tables. I will join these tables on the \"CustomerId\" column and group the results by the \"BillingCountry\" column. Then, I will calculate the sum of the \"Total\" column to get the total sales per country. Finally, I will order the results in descending order of the total sales.\n", - "\n", - "Here is the SQL query:\n", - "\n", - "```sql\n", - "SELECT c.Country, SUM(i.Total) AS TotalSales\n", - "FROM Invoice i\n", - "JOIN Customer c ON i.CustomerId = c.CustomerId\n", - "GROUP BY c.Country\n", - "ORDER BY TotalSales DESC\n", - "```\n", - "\n", - "Now, I will execute this query to get the results.\n", - "\n", - "\u001b[0m\u001b[36;1m\u001b[1;3m[('USA', 523.0600000000003), ('Canada', 303.9599999999999), ('France', 195.09999999999994), ('Brazil', 190.09999999999997), ('Germany', 156.48), ('United Kingdom', 112.85999999999999), ('Czech Republic', 90.24000000000001), ('Portugal', 77.23999999999998), ('India', 75.25999999999999), ('Chile', 46.62), ('Ireland', 45.62), ('Hungary', 45.62), ('Austria', 42.62), ('Finland', 41.620000000000005), ('Netherlands', 40.62), ('Norway', 39.62), ('Sweden', 38.620000000000005), ('Poland', 37.620000000000005), ('Italy', 37.620000000000005), ('Denmark', 37.620000000000005), ('Australia', 37.620000000000005), ('Argentina', 37.620000000000005), ('Spain', 37.62), ('Belgium', 37.62)]\u001b[0m\u001b[32;1m\u001b[1;3mThe total sales per country are as follows:\n", - "\n", - "1. USA: $523.06\n", - "2. Canada: $303.96\n", - "3. France: $195.10\n", - "4. Brazil: $190.10\n", - "5. Germany: $156.48\n", - "6. United Kingdom: $112.86\n", - "7. 
Czech Republic: $90.24\n", - "8. Portugal: $77.24\n", - "9. India: $75.26\n", - "10. Chile: $46.62\n", - "\n", - "The country whose customers spent the most is the USA, with a total sales of $523.06.\u001b[0m\n", - "\n", - "\u001b[1m> Finished chain.\u001b[0m\n" - ] - }, - { - "data": { - "text/plain": [ - "{'input': \"List the total sales per country. Which country's customers spent the most?\",\n", - " 'output': 'The total sales per country are as follows:\\n\\n1. USA: $523.06\\n2. Canada: $303.96\\n3. France: $195.10\\n4. Brazil: $190.10\\n5. Germany: $156.48\\n6. United Kingdom: $112.86\\n7. Czech Republic: $90.24\\n8. Portugal: $77.24\\n9. India: $75.26\\n10. Chile: $46.62\\n\\nThe country whose customers spent the most is the USA, with a total sales of $523.06.'}" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "agent_executor.invoke(\n", - " \"List the total sales per country. Which country's customers spent the most?\"\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "\n", - "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", - "\u001b[32;1m\u001b[1;3m\n", - "Invoking: `sql_db_list_tables` with `{}`\n", - "\n", - "\n", - "\u001b[0m\u001b[38;5;200m\u001b[1;3mAlbum, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track\u001b[0m\u001b[32;1m\u001b[1;3m\n", - "Invoking: `sql_db_schema` with `PlaylistTrack`\n", - "\n", - "\n", - "\u001b[0m\u001b[33;1m\u001b[1;3m\n", - "CREATE TABLE \"PlaylistTrack\" (\n", - "\t\"PlaylistId\" INTEGER NOT NULL, \n", - "\t\"TrackId\" INTEGER NOT NULL, \n", - "\tPRIMARY KEY (\"PlaylistId\", \"TrackId\"), \n", - "\tFOREIGN KEY(\"TrackId\") REFERENCES \"Track\" (\"TrackId\"), \n", - "\tFOREIGN KEY(\"PlaylistId\") REFERENCES \"Playlist\" (\"PlaylistId\")\n", - ")\n", - "\n", - "/*\n", - "3 rows from PlaylistTrack table:\n", - "PlaylistId\tTrackId\n", - "1\t3402\n", - "1\t3389\n", - "1\t3390\n", - "*/\u001b[0m\u001b[32;1m\u001b[1;3mThe `PlaylistTrack` table has two columns: `PlaylistId` and `TrackId`. It is a junction table that represents the many-to-many relationship between playlists and tracks. \n", - "\n", - "Here is the schema of the `PlaylistTrack` table:\n", - "\n", - "```\n", - "CREATE TABLE \"PlaylistTrack\" (\n", - "\t\"PlaylistId\" INTEGER NOT NULL, \n", - "\t\"TrackId\" INTEGER NOT NULL, \n", - "\tPRIMARY KEY (\"PlaylistId\", \"TrackId\"), \n", - "\tFOREIGN KEY(\"TrackId\") REFERENCES \"Track\" (\"TrackId\"), \n", - "\tFOREIGN KEY(\"PlaylistId\") REFERENCES \"Playlist\" (\"PlaylistId\")\n", - ")\n", - "```\n", - "\n", - "The `PlaylistId` column is a foreign key referencing the `PlaylistId` column in the `Playlist` table. The `TrackId` column is a foreign key referencing the `TrackId` column in the `Track` table.\n", - "\n", - "Here are three sample rows from the `PlaylistTrack` table:\n", - "\n", - "```\n", - "PlaylistId TrackId\n", - "1 3402\n", - "1 3389\n", - "1 3390\n", - "```\n", - "\n", - "Please let me know if there is anything else I can help with.\u001b[0m\n", - "\n", - "\u001b[1m> Finished chain.\u001b[0m\n" - ] - }, - { - "data": { - "text/plain": [ - "{'input': 'Describe the playlisttrack table',\n", - " 'output': 'The `PlaylistTrack` table has two columns: `PlaylistId` and `TrackId`. It is a junction table that represents the many-to-many relationship between playlists and tracks. 
\\n\\nHere is the schema of the `PlaylistTrack` table:\\n\\n```\\nCREATE TABLE \"PlaylistTrack\" (\\n\\t\"PlaylistId\" INTEGER NOT NULL, \\n\\t\"TrackId\" INTEGER NOT NULL, \\n\\tPRIMARY KEY (\"PlaylistId\", \"TrackId\"), \\n\\tFOREIGN KEY(\"TrackId\") REFERENCES \"Track\" (\"TrackId\"), \\n\\tFOREIGN KEY(\"PlaylistId\") REFERENCES \"Playlist\" (\"PlaylistId\")\\n)\\n```\\n\\nThe `PlaylistId` column is a foreign key referencing the `PlaylistId` column in the `Playlist` table. The `TrackId` column is a foreign key referencing the `TrackId` column in the `Track` table.\\n\\nHere are three sample rows from the `PlaylistTrack` table:\\n\\n```\\nPlaylistId TrackId\\n1 3402\\n1 3389\\n1 3390\\n```\\n\\nPlease let me know if there is anything else I can help with.'}" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "agent_executor.invoke(\"Describe the playlisttrack table\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Using a dynamic few-shot prompt\n", - "\n", - "To optimize agent performance, we can provide a custom prompt with domain-specific knowledge. In this case we'll create a few shot prompt with an example selector, that will dynamically build the few shot prompt based on the user input.\n", - "\n", - "First we need some user input <> SQL query examples:" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "examples = [\n", - " {\"input\": \"List all artists.\", \"query\": \"SELECT * FROM Artist;\"},\n", - " {\n", - " \"input\": \"Find all albums for the artist 'AC/DC'.\",\n", - " \"query\": \"SELECT * FROM Album WHERE ArtistId = (SELECT ArtistId FROM Artist WHERE Name = 'AC/DC');\",\n", - " },\n", - " {\n", - " \"input\": \"List all tracks in the 'Rock' genre.\",\n", - " \"query\": \"SELECT * FROM Track WHERE GenreId = (SELECT GenreId FROM Genre WHERE Name = 'Rock');\",\n", - " },\n", - " {\n", - " \"input\": \"Find the total duration of all tracks.\",\n", - " \"query\": \"SELECT SUM(Milliseconds) FROM Track;\",\n", - " },\n", - " {\n", - " \"input\": \"List all customers from Canada.\",\n", - " \"query\": \"SELECT * FROM Customer WHERE Country = 'Canada';\",\n", - " },\n", - " {\n", - " \"input\": \"How many tracks are there in the album with ID 5?\",\n", - " \"query\": \"SELECT COUNT(*) FROM Track WHERE AlbumId = 5;\",\n", - " },\n", - " {\n", - " \"input\": \"Find the total number of invoices.\",\n", - " \"query\": \"SELECT COUNT(*) FROM Invoice;\",\n", - " },\n", - " {\n", - " \"input\": \"List all tracks that are longer than 5 minutes.\",\n", - " \"query\": \"SELECT * FROM Track WHERE Milliseconds > 300000;\",\n", - " },\n", - " {\n", - " \"input\": \"Who are the top 5 customers by total purchase?\",\n", - " \"query\": \"SELECT CustomerId, SUM(Total) AS TotalPurchase FROM Invoice GROUP BY CustomerId ORDER BY TotalPurchase DESC LIMIT 5;\",\n", - " },\n", - " {\n", - " \"input\": \"Which albums are from the year 2000?\",\n", - " \"query\": \"SELECT * FROM Album WHERE strftime('%Y', ReleaseDate) = '2000';\",\n", - " },\n", - " {\n", - " \"input\": \"How many employees are there\",\n", - " \"query\": 'SELECT COUNT(*) FROM \"Employee\"',\n", - " },\n", - "]" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now we can create an example selector. This will take the actual user input and select some number of examples to add to our few-shot prompt. 
We'll use a SemanticSimilarityExampleSelector, which will perform a semantic search using the embeddings and vector store we configure to find the examples most similar to our input:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from langchain_community.vectorstores import FAISS\n", - "from langchain_core.example_selectors import SemanticSimilarityExampleSelector\n", - "from langchain_openai import OpenAIEmbeddings\n", - "\n", - "example_selector = SemanticSimilarityExampleSelector.from_examples(\n", - " examples,\n", - " OpenAIEmbeddings(),\n", - " FAISS,\n", - " k=5,\n", - " input_keys=[\"input\"],\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now we can create our FewShotPromptTemplate, which takes our example selector, an example prompt for formatting each example, and a string prefix and suffix to put before and after our formatted examples:" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": {}, - "outputs": [], - "source": [ - "from langchain_core.prompts import (\n", - " ChatPromptTemplate,\n", - " FewShotPromptTemplate,\n", - " MessagesPlaceholder,\n", - " PromptTemplate,\n", - " SystemMessagePromptTemplate,\n", - ")\n", - "\n", - "system_prefix = \"\"\"You are an agent designed to interact with a SQL database.\n", - "Given an input question, create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.\n", - "Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.\n", - "You can order the results by a relevant column to return the most interesting examples in the database.\n", - "Never query for all the columns from a specific table, only ask for the relevant columns given the question.\n", - "You have access to tools for interacting with the database.\n", - "Only use the given tools. Only use the information returned by the tools to construct your final answer.\n", - "You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.\n", - "\n", - "DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.\n", - "\n", - "If the question does not seem related to the database, just return \"I don't know\" as the answer.\n", - "\n", - "Here are some examples of user inputs and their corresponding SQL queries:\"\"\"\n", - "\n", - "few_shot_prompt = FewShotPromptTemplate(\n", - " example_selector=example_selector,\n", - " example_prompt=PromptTemplate.from_template(\n", - " \"User input: {input}\\nSQL query: {query}\"\n", - " ),\n", - " input_variables=[\"input\", \"dialect\", \"top_k\"],\n", - " prefix=system_prefix,\n", - " suffix=\"\",\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Since our underlying agent is an [OpenAI tools agent](/docs/modules/agents/agent_types/openai_tools), which uses OpenAI function calling, our full prompt should be a chat prompt with a human message template and an agent_scratchpad `MessagesPlaceholder`. 
The few-shot prompt will be used for our system message:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "full_prompt = ChatPromptTemplate.from_messages(\n", - " [\n", - " SystemMessagePromptTemplate(prompt=few_shot_prompt),\n", - " (\"human\", \"{input}\"),\n", - " MessagesPlaceholder(\"agent_scratchpad\"),\n", - " ]\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "System: You are an agent designed to interact with a SQL database.\n", - "Given an input question, create a syntactically correct SQLite query to run, then look at the results of the query and return the answer.\n", - "Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 5 results.\n", - "You can order the results by a relevant column to return the most interesting examples in the database.\n", - "Never query for all the columns from a specific table, only ask for the relevant columns given the question.\n", - "You have access to tools for interacting with the database.\n", - "Only use the given tools. Only use the information returned by the tools to construct your final answer.\n", - "You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.\n", - "\n", - "DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.\n", - "\n", - "If the question does not seem related to the database, just return \"I don't know\" as the answer.\n", - "\n", - "Here are some examples of user inputs and their corresponding SQL queries:\n", - "\n", - "User input: List all artists.\n", - "SQL query: SELECT * FROM Artist;\n", - "\n", - "User input: How many employees are there\n", - "SQL query: SELECT COUNT(*) FROM \"Employee\"\n", - "\n", - "User input: How many tracks are there in the album with ID 5?\n", - "SQL query: SELECT COUNT(*) FROM Track WHERE AlbumId = 5;\n", - "\n", - "User input: List all tracks in the 'Rock' genre.\n", - "SQL query: SELECT * FROM Track WHERE GenreId = (SELECT GenreId FROM Genre WHERE Name = 'Rock');\n", - "\n", - "User input: Which albums are from the year 2000?\n", - "SQL query: SELECT * FROM Album WHERE strftime('%Y', ReleaseDate) = '2000';\n", - "Human: How many arists are there\n" - ] - } - ], - "source": [ - "# Example formatted prompt\n", - "prompt_val = full_prompt.invoke(\n", - " {\n", - " \"input\": \"How many arists are there\",\n", - " \"top_k\": 5,\n", - " \"dialect\": \"SQLite\",\n", - " \"agent_scratchpad\": [],\n", - " }\n", - ")\n", - "print(prompt_val.to_string())" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "And now we can create our agent with our custom prompt:" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [], - "source": [ - "agent = create_sql_agent(\n", - " llm=llm,\n", - " db=db,\n", - " prompt=full_prompt,\n", - " verbose=True,\n", - " agent_type=\"openai-tools\",\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Let's try it out:" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "\n", - "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", - "\u001b[32;1m\u001b[1;3m\n", - "Invoking: `sql_db_query` with `{'query': 'SELECT 
COUNT(*) FROM Artist'}`\n", - "\n", - "\n", - "\u001b[0m\u001b[36;1m\u001b[1;3m[(275,)]\u001b[0m\u001b[32;1m\u001b[1;3mThere are 275 artists in the database.\u001b[0m\n", - "\n", - "\u001b[1m> Finished chain.\u001b[0m\n" - ] - }, - { - "data": { - "text/plain": [ - "{'input': 'How many artists are there?',\n", - " 'output': 'There are 275 artists in the database.'}" - ] - }, - "execution_count": 22, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "agent.invoke({\"input\": \"How many artists are there?\"})" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Dealing with high-cardinality columns\n", - "\n", - "In order to filter columns that contain proper nouns such as addresses, song names or artists, we first need to double-check the spelling in order to filter the data correctly. \n", - "\n", - "We can achieve this by creating a vector store with all the distinct proper nouns that exist in the database. We can then have the agent query that vector store each time the user includes a proper noun in their question, to find the correct spelling for that word. In this way, the agent can make sure it understands which entity the user is referring to before building the target query.\n", - "\n", - "First we need the unique values for each entity we want, for which we define a function that parses the result into a list of elements:" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "['For Those About To Rock We Salute You',\n", - " 'Balls to the Wall',\n", - " 'Restless and Wild',\n", - " 'Let There Be Rock',\n", - " 'Big Ones']" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import ast\n", - "import re\n", - "\n", - "\n", - "def query_as_list(db, query):\n", - " res = db.run(query)\n", - " res = [el for sub in ast.literal_eval(res) for el in sub if el]\n", - " res = [re.sub(r\"\\b\\d+\\b\", \"\", string).strip() for string in res]\n", - " return res\n", - "\n", - "\n", - "artists = query_as_list(db, \"SELECT Name FROM Artist\")\n", - "albums = query_as_list(db, \"SELECT Title FROM Album\")\n", - "albums[:5]" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now we can proceed with creating the custom **retriever tool** and the final agent:" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [], - "source": [ - "from langchain.agents.agent_toolkits import create_retriever_tool\n", - "\n", - "vector_db = FAISS.from_texts(artists + albums, OpenAIEmbeddings())\n", - "retriever = vector_db.as_retriever(search_kwargs={\"k\": 5})\n", - "description = \"\"\"Use to look up values to filter on. Input is an approximate spelling of the proper noun, output is \\\n", - "valid proper nouns. 
Use the noun most similar to the search.\"\"\"\n", - "retriever_tool = create_retriever_tool(\n", - " retriever,\n", - " name=\"search_proper_nouns\",\n", - " description=description,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [], - "source": [ - "system = \"\"\"You are an agent designed to interact with a SQL database.\n", - "Given an input question, create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.\n", - "Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.\n", - "You can order the results by a relevant column to return the most interesting examples in the database.\n", - "Never query for all the columns from a specific table, only ask for the relevant columns given the question.\n", - "You have access to tools for interacting with the database.\n", - "Only use the given tools. Only use the information returned by the tools to construct your final answer.\n", - "You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.\n", - "\n", - "DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.\n", - "\n", - "If you need to filter on a proper noun, you must ALWAYS first look up the filter value using the \"search_proper_nouns\" tool!\n", - "\n", - "You have access to the following tables: {table_names}\n", - "\n", - "If the question does not seem related to the database, just return \"I don't know\" as the answer.\"\"\"\n", - "\n", - "prompt = ChatPromptTemplate.from_messages([(\"system\", system), (\"human\", \"{input}\"), MessagesPlaceholder(\"agent_scratchpad\")])\n", - "agent = create_sql_agent(\n", - " llm=llm,\n", - " db=db,\n", - " extra_tools=[retriever_tool],\n", - " prompt=prompt,\n", - " agent_type=\"openai-tools\",\n", - " verbose=True,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "\n", - "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", - "\u001b[32;1m\u001b[1;3m\n", - "Invoking: `search_proper_nouns` with `{'query': 'alice in chains'}`\n", - "\n", - "\n", - "\u001b[0m\u001b[36;1m\u001b[1;3mAlice In Chains\n", - "\n", - "Metallica\n", - "\n", - "Pearl Jam\n", - "\n", - "Pearl Jam\n", - "\n", - "Smashing Pumpkins\u001b[0m\u001b[32;1m\u001b[1;3m\n", - "Invoking: `sql_db_query` with `{'query': \"SELECT COUNT(*) FROM Album WHERE ArtistId = (SELECT ArtistId FROM Artist WHERE Name = 'Alice In Chains')\"}`\n", - "\n", - "\n", - "\u001b[0m\u001b[36;1m\u001b[1;3m[(1,)]\u001b[0m\u001b[32;1m\u001b[1;3mAlice In Chains has 1 album.\u001b[0m\n", - "\n", - "\u001b[1m> Finished chain.\u001b[0m\n" - ] - }, - { - "data": { - "text/plain": [ - "{'input': 'How many albums does alice in chains have?',\n", - " 'output': 'Alice In Chains has 1 album.'}" - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "agent.invoke({\"input\": \"How many albums does alice in chains have?\"})" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "As we can see, the agent used the `search_proper_nouns` tool in order to check how to correctly query the database for this specific artist." 
- ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Next steps\n", - "\n", - "Under the hood, `create_sql_agent` is just passing in SQL tools to more generic agent constructors. To learn more about the built-in generic agent types as well as how to build custom agents, head to the [Agents Modules](/docs/modules/agents/).\n", - "\n", - "The built-in `AgentExecutor` runs a simple Agent action -> Tool call -> Agent action... loop. To build more complex agent runtimes, head to the [LangGraph section](/docs/langgraph)." - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "poetry-venv", - "language": "python", - "name": "poetry-venv" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.1" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/docs/docs/use_cases/sql/index.ipynb b/docs/docs/use_cases/sql/index.ipynb deleted file mode 100644 index 1b706c831b168..0000000000000 --- a/docs/docs/use_cases/sql/index.ipynb +++ /dev/null @@ -1,68 +0,0 @@ -{ - "cells": [ - { - "cell_type": "raw", - "metadata": {}, - "source": [ - "---\n", - "sidebar_position: 0.5\n", - "---" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# SQL\n", - "\n", - "One of the most common types of databases that we can build Q&A systems for are SQL databases. LangChain comes with a number of built-in chains and agents that are compatible with any SQL dialect supported by SQLAlchemy (e.g., MySQL, PostgreSQL, Oracle SQL, Databricks, SQLite). They enable use cases such as:\n", - "\n", - "* Generating queries that will be run based on natural language questions,\n", - "* Creating chatbots that can answer questions based on database data,\n", - "* Building custom dashboards based on insights a user wants to analyze,\n", - "\n", - "and much more.\n", - "\n", - "## ⚠️ Security note ⚠️\n", - "\n", - "Building Q&A systems of SQL databases requires executing model-generated SQL queries. There are inherent risks in doing this. Make sure that your database connection permissions are always scoped as narrowly as possible for your chain/agent's needs. This will mitigate though not eliminate the risks of building a model-driven system. For more on general security best practices, [see here](/docs/security).\n", - "\n", - "![sql_usecase.png](../../../static/img/sql_usecase.png)\n", - "\n", - "## Quickstart\n", - "\n", - "Head to the **[Quickstart](/docs/use_cases/sql/quickstart)** page to get started.\n", - "\n", - "## Advanced\n", - "\n", - "Once you've familiarized yourself with the basics, you can head to the advanced guides:\n", - "\n", - "* [Agents](/docs/use_cases/sql/agents): Building agents that can interact with SQL DBs.\n", - "* [Prompting strategies](/docs/use_cases/sql/prompting): Strategies for improving SQL query generation.\n", - "* [Query validation](/docs/use_cases/sql/query_checking): How to validate SQL queries.\n", - "* [Large databases](/docs/use_cases/sql/large_db): How to interact with DBs with many tables and high-cardinality columns." 
- ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "poetry-venv", - "language": "python", - "name": "poetry-venv" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.1" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/docs/docs/use_cases/sql/large_db.ipynb b/docs/docs/use_cases/sql/large_db.ipynb deleted file mode 100644 index dd034037d1a4b..0000000000000 --- a/docs/docs/use_cases/sql/large_db.ipynb +++ /dev/null @@ -1,627 +0,0 @@ -{ - "cells": [ - { - "cell_type": "raw", - "id": "b2788654-3f62-4e2a-ab00-471922cc54df", - "metadata": {}, - "source": [ - "---\n", - "sidebar_position: 4\n", - "---" - ] - }, - { - "cell_type": "markdown", - "id": "6751831d-9b08-434f-829b-d0052a3b119f", - "metadata": {}, - "source": [ - "# Large databases\n", - "\n", - "In order to write valid queries against a database, we need to feed the model the table names, table schemas, and feature values for it to query over. When there are many tables, columns, and/or high-cardinality columns, it becomes impossible for us to dump the full information about our database in every prompt. Instead, we must find ways to dynamically insert into the prompt only the most relevant information. Let's take a look at some techniques for doing this.\n", - "\n", - "\n", - "## Setup\n", - "\n", - "First, get required packages and set environment variables:" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "9675e433-e608-469e-b04e-2847479a8310", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.2.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.3.2\u001b[0m\n", - "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n", - "Note: you may need to restart the kernel to use updated packages.\n" - ] - } - ], - "source": [ - "%pip install --upgrade --quiet langchain langchain-community langchain-openai" - ] - }, - { - "cell_type": "markdown", - "id": "4f56ff5d-b2e4-49e3-a0b4-fb99466cfedc", - "metadata": {}, - "source": [ - "We default to OpenAI models in this guide, but you can swap them out for the model provider of your choice." - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "06d8dd03-2d7b-4fef-b145-43c074eacb8b", - "metadata": {}, - "outputs": [ - { - "name": "stdin", - "output_type": "stream", - "text": [ - " ········\n" - ] - } - ], - "source": [ - "import getpass\n", - "import os\n", - "\n", - "# os.environ[\"OPENAI_API_KEY\"] = getpass.getpass()\n", - "\n", - "# Uncomment the below to use LangSmith. Not required.\n", - "os.environ[\"LANGCHAIN_API_KEY\"] = getpass.getpass()\n", - "# os.environ[\"LANGCHAIN_TRACING_V2\"] = \"true\"" - ] - }, - { - "cell_type": "markdown", - "id": "590ee096-db88-42af-90d4-99b8149df753", - "metadata": {}, - "source": [ - "The below example will use a SQLite connection with Chinook database. 
Follow [these installation steps](https://database.guide/2-sample-databases-sqlite/) to create `Chinook.db` in the same directory as this notebook:\n", - "\n", - "* Save [this file](https://raw.githubusercontent.com/lerocha/chinook-database/master/ChinookDatabase/DataSources/Chinook_Sqlite.sql) as `Chinook_Sqlite.sql`\n", - "* Run `sqlite3 Chinook.db`\n", - "* Run `.read Chinook_Sqlite.sql`\n", - "* Test `SELECT * FROM Artist LIMIT 10;`\n", - "\n", - "Now, `Chinhook.db` is in our directory and we can interface with it using the SQLAlchemy-driven [SQLDatabase](https://api.python.langchain.com/en/latest/utilities/langchain_community.utilities.sql_database.SQLDatabase.html) class:" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "cebd3915-f58f-4e73-8459-265630ae8cd4", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "sqlite\n", - "['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']\n" - ] - }, - { - "data": { - "text/plain": [ - "\"[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]\"" - ] - }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from langchain_community.utilities import SQLDatabase\n", - "\n", - "db = SQLDatabase.from_uri(\"sqlite:///Chinook.db\")\n", - "print(db.dialect)\n", - "print(db.get_usable_table_names())\n", - "db.run(\"SELECT * FROM Artist LIMIT 10;\")" - ] - }, - { - "cell_type": "markdown", - "id": "2e572e1f-99b5-46a2-9023-76d1e6256c0a", - "metadata": {}, - "source": [ - "## Many tables\n", - "\n", - "One of the main pieces of information we need to include in our prompt is the schemas of the relevant tables. When we have very many tables, we can't fit all of the schemas in a single prompt. What we can do in such cases is first extract the names of the tables related to the user input, and then include only their schemas.\n", - "\n", - "One easy and reliable way to do this is using OpenAI function-calling and Pydantic models. LangChain comes with a built-in [create_extraction_chain_pydantic](https://api.python.langchain.com/en/latest/chains/langchain.chains.openai_tools.extraction.create_extraction_chain_pydantic.html) chain that lets us do just this:" - ] - }, - { - "cell_type": "code", - "execution_count": 38, - "id": "d8236886-c54f-4bdb-ad74-2514888628fd", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[Table(name='Genre'), Table(name='Artist'), Table(name='Track')]" - ] - }, - "execution_count": 38, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from langchain.chains.openai_tools import create_extraction_chain_pydantic\n", - "from langchain_core.pydantic_v1 import BaseModel, Field\n", - "from langchain_openai import ChatOpenAI\n", - "\n", - "llm = ChatOpenAI(model=\"gpt-3.5-turbo-1106\", temperature=0)\n", - "\n", - "\n", - "class Table(BaseModel):\n", - " \"\"\"Table in SQL database.\"\"\"\n", - "\n", - " name: str = Field(description=\"Name of table in SQL database.\")\n", - "\n", - "\n", - "table_names = \"\\n\".join(db.get_usable_table_names())\n", - "system = f\"\"\"Return the names of ALL the SQL tables that MIGHT be relevant to the user question. 
\\\n", - "The tables are:\n", - "\n", - "{table_names}\n", - "\n", - "Remember to include ALL POTENTIALLY RELEVANT tables, even if you're not sure that they're needed.\"\"\"\n", - "table_chain = create_extraction_chain_pydantic(Table, llm, system_message=system)\n", - "table_chain.invoke({\"input\": \"What are all the genres of Alanis Morisette songs\"})" - ] - }, - { - "cell_type": "markdown", - "id": "1641dbba-d359-4cb2-ac52-82dfae99f392", - "metadata": {}, - "source": [ - "This works pretty well! Except, as we'll see below, we actually need a few other tables as well. This would be pretty difficult for the model to know based just on the user question. In this case, we might think to simplify our model's job by grouping the tables together. We'll just ask the model to choose between categories \"Music\" and \"Business\", and then take care of selecting all the relevant tables from there:" - ] - }, - { - "cell_type": "code", - "execution_count": 59, - "id": "0ccb0bf5-c580-428f-9cde-a58772ae784e", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[Table(name='Music')]" - ] - }, - "execution_count": 59, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "system = f\"\"\"Return the names of the SQL tables that are relevant to the user question. \\\n", - "The tables are:\n", - "\n", - "Music\n", - "Business\"\"\"\n", - "category_chain = create_extraction_chain_pydantic(Table, llm, system_message=system)\n", - "category_chain.invoke({\"input\": \"What are all the genres of Alanis Morisette songs\"})" - ] - }, - { - "cell_type": "code", - "execution_count": 60, - "id": "ae4899fc-6f8a-4b10-983c-9e3fef4a7bb9", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "['Album', 'Artist', 'Genre', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']" - ] - }, - "execution_count": 60, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from typing import List\n", - "\n", - "\n", - "def get_tables(categories: List[Table]) -> List[str]:\n", - " tables = []\n", - " for category in categories:\n", - " if category.name == \"Music\":\n", - " tables.extend(\n", - " [\n", - " \"Album\",\n", - " \"Artist\",\n", - " \"Genre\",\n", - " \"MediaType\",\n", - " \"Playlist\",\n", - " \"PlaylistTrack\",\n", - " \"Track\",\n", - " ]\n", - " )\n", - " elif category.name == \"Business\":\n", - " tables.extend([\"Customer\", \"Employee\", \"Invoice\", \"InvoiceLine\"])\n", - " return tables\n", - "\n", - "\n", - "table_chain = category_chain | get_tables # noqa\n", - "table_chain.invoke({\"input\": \"What are all the genres of Alanis Morisette songs\"})" - ] - }, - { - "cell_type": "markdown", - "id": "04d52d01-1ccf-4753-b34a-0dcbc4921f78", - "metadata": {}, - "source": [ - "Now that we've got a chain that can output the relevant tables for any query we can combine this with our [create_sql_query_chain](https://api.python.langchain.com/en/latest/chains/langchain.chains.sql_database.query.create_sql_query_chain.html), which can accept a list of `table_names_to_use` to determine which table schemas are included in the prompt:" - ] - }, - { - "cell_type": "code", - "execution_count": 51, - "id": "79f2a5a2-eb99-47e3-9c2b-e5751a800174", - "metadata": {}, - "outputs": [], - "source": [ - "from operator import itemgetter\n", - "\n", - "from langchain.chains import create_sql_query_chain\n", - "from langchain_core.runnables import RunnablePassthrough\n", - "\n", - "query_chain = create_sql_query_chain(llm, db)\n", - "# Convert \"question\" key 
to the \"input\" key expected by current table_chain.\n", - "table_chain = {\"input\": itemgetter(\"question\")} | table_chain\n", - "# Set table_names_to_use using table_chain.\n", - "full_chain = RunnablePassthrough.assign(table_names_to_use=table_chain) | query_chain" - ] - }, - { - "cell_type": "code", - "execution_count": 52, - "id": "424a7564-f63c-4584-b734-88021926486d", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "SELECT \"Genre\".\"Name\"\n", - "FROM \"Genre\"\n", - "JOIN \"Track\" ON \"Genre\".\"GenreId\" = \"Track\".\"GenreId\"\n", - "JOIN \"Album\" ON \"Track\".\"AlbumId\" = \"Album\".\"AlbumId\"\n", - "JOIN \"Artist\" ON \"Album\".\"ArtistId\" = \"Artist\".\"ArtistId\"\n", - "WHERE \"Artist\".\"Name\" = 'Alanis Morissette'\n" - ] - } - ], - "source": [ - "query = full_chain.invoke(\n", - " {\"question\": \"What are all the genres of Alanis Morisette songs\"}\n", - ")\n", - "print(query)" - ] - }, - { - "cell_type": "code", - "execution_count": 53, - "id": "3fb715cf-69d1-46a6-a1a7-9715ee550a0c", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "\"[('Rock',), ('Rock',), ('Rock',), ('Rock',), ('Rock',), ('Rock',), ('Rock',), ('Rock',), ('Rock',), ('Rock',), ('Rock',), ('Rock',), ('Rock',)]\"" - ] - }, - "execution_count": 53, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "db.run(query)" - ] - }, - { - "cell_type": "markdown", - "id": "bb3d12b0-81a6-4250-8bc4-d58fe762c4cc", - "metadata": {}, - "source": [ - "We might rephrase our question slightly to remove redundancy in the answer" - ] - }, - { - "cell_type": "code", - "execution_count": 57, - "id": "010b5c3c-d55b-461a-8de5-8f1a8b2c56ec", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "SELECT DISTINCT g.Name\n", - "FROM Genre g\n", - "JOIN Track t ON g.GenreId = t.GenreId\n", - "JOIN Album a ON t.AlbumId = a.AlbumId\n", - "JOIN Artist ar ON a.ArtistId = ar.ArtistId\n", - "WHERE ar.Name = 'Alanis Morissette'\n" - ] - } - ], - "source": [ - "query = full_chain.invoke(\n", - " {\"question\": \"What is the set of all unique genres of Alanis Morisette songs\"}\n", - ")\n", - "print(query)" - ] - }, - { - "cell_type": "code", - "execution_count": 58, - "id": "d21c0563-1f55-4577-8222-b0e9802f1c4b", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "\"[('Rock',)]\"" - ] - }, - "execution_count": 58, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "db.run(query)" - ] - }, - { - "cell_type": "markdown", - "id": "7a717020-84c2-40f3-ba84-6624138d8e0c", - "metadata": {}, - "source": [ - "We can see the [LangSmith trace](https://smith.langchain.com/public/20b8ef90-1dac-4754-90f0-6bc11203c50a/r) for this run here.\n", - "\n", - "We've seen how to dynamically include a subset of table schemas in a prompt within a chain. Another possible approach to this problem is to let an Agent decide for itself when to look up tables by giving it a Tool to do so. You can see an example of this in the [SQL: Agents](/docs/use_cases/sql/agents) guide." - ] - }, - { - "cell_type": "markdown", - "id": "cb9e54fd-64ca-4ed5-847c-afc635aae4f5", - "metadata": {}, - "source": [ - "## High-cardinality columns\n", - "\n", - "In order to filter columns that contain proper nouns such as addresses, song names or artists, we first need to double-check the spelling in order to filter the data correctly. 
\n",
 - "\n",
 - "One naive strategy is to create a vector store with all the distinct proper nouns that exist in the database. We can then query that vector store with each user input and inject the most relevant proper nouns into the prompt.\n",
 - "\n",
 - "First we need the unique values for each entity we want, for which we define a function that parses the result into a list of elements:"
 - ]
 - },
 - {
 - "cell_type": "code",
 - "execution_count": 6,
 - "id": "dee1b9e1-36b0-4cc1-ab78-7a872ad87e29",
 - "metadata": {},
 - "outputs": [
 - {
 - "data": {
 - "text/plain": [
 - "['AC/DC', 'Accept', 'Aerosmith', 'Alanis Morissette', 'Alice In Chains']"
 - ]
 - },
 - "execution_count": 6,
 - "metadata": {},
 - "output_type": "execute_result"
 - }
 - ],
 - "source": [
 - "import ast\n",
 - "import re\n",
 - "\n",
 - "\n",
 - "def query_as_list(db, query):\n",
 - " res = db.run(query)\n",
 - " res = [el for sub in ast.literal_eval(res) for el in sub if el]\n",
 - " res = [re.sub(r\"\\b\\d+\\b\", \"\", string).strip() for string in res]\n",
 - " return res\n",
 - "\n",
 - "\n",
 - "proper_nouns = query_as_list(db, \"SELECT Name FROM Artist\")\n",
 - "proper_nouns += query_as_list(db, \"SELECT Title FROM Album\")\n",
 - "proper_nouns += query_as_list(db, \"SELECT Name FROM Genre\")\n",
 - "len(proper_nouns)\n",
 - "proper_nouns[:5]"
 - ]
 - },
 - {
 - "cell_type": "markdown",
 - "id": "22efa968-1879-4d7a-858f-7899dfa57454",
 - "metadata": {},
 - "source": [
 - "Now we can embed and store all of our values in a vector database:"
 - ]
 - },
 - {
 - "cell_type": "code",
 - "execution_count": 6,
 - "id": "ea50abce-545a-4dc3-8795-8d364f7d142a",
 - "metadata": {},
 - "outputs": [],
 - "source": [
 - "from langchain_community.vectorstores import FAISS\n",
 - "from langchain_openai import OpenAIEmbeddings\n",
 - "\n",
 - "vector_db = FAISS.from_texts(proper_nouns, OpenAIEmbeddings())\n",
 - "retriever = vector_db.as_retriever(search_kwargs={\"k\": 15})"
 - ]
 - },
 - {
 - "cell_type": "markdown",
 - "id": "a5d1d5c0-0928-40a4-b961-f1afe03cd5d3",
 - "metadata": {},
 - "source": [
 - "And put together a query construction chain that first retrieves values from the database and inserts them into the prompt:"
 - ]
 - },
 - {
 - "cell_type": "code",
 - "execution_count": 11,
 - "id": "aea123ae-d809-44a0-be5d-d883c60d6a11",
 - "metadata": {},
 - "outputs": [],
 - "source": [
 - "from operator import itemgetter\n",
 - "\n",
 - "from langchain_core.prompts import ChatPromptTemplate\n",
 - "from langchain_core.runnables import RunnablePassthrough\n",
 - "\n",
 - "system = \"\"\"You are a SQLite expert. Given an input question, create a syntactically \\\n",
 - "correct SQLite query to run. Unless otherwise specified, do not return more than \\\n",
 - "{top_k} rows.\\n\\nHere is the relevant table info: {table_info}\\n\\nHere is a non-exhaustive \\\n",
 - "list of possible feature values. 
If filtering on a feature value, make sure to check its spelling \\\n",
 - "against this list first:\\n\\n{proper_nouns}\"\"\"\n",
 - "\n",
 - "prompt = ChatPromptTemplate.from_messages([(\"system\", system), (\"human\", \"{input}\")])\n",
 - "\n",
 - "query_chain = create_sql_query_chain(llm, db, prompt=prompt)\n",
 - "retriever_chain = (\n",
 - " itemgetter(\"question\")\n",
 - " | retriever\n",
 - " | (lambda docs: \"\\n\".join(doc.page_content for doc in docs))\n",
 - ")\n",
 - "chain = RunnablePassthrough.assign(proper_nouns=retriever_chain) | query_chain"
 - ]
 - },
 - {
 - "cell_type": "markdown",
 - "id": "12b0ed60-2536-4f82-85df-e096a272072a",
 - "metadata": {},
 - "source": [
 - "To try out our chain, let's see what happens when we try filtering on \"elenis moriset\", a misspelling of Alanis Morissette, without and with retrieval:"
 - ]
 - },
 - {
 - "cell_type": "code",
 - "execution_count": 15,
 - "id": "fcdd8432-07a4-4609-8214-b1591dd94950",
 - "metadata": {},
 - "outputs": [
 - {
 - "name": "stdout",
 - "output_type": "stream",
 - "text": [
 - "SELECT DISTINCT Genre.Name\n",
 - "FROM Genre\n",
 - "JOIN Track ON Genre.GenreId = Track.GenreId\n",
 - "JOIN Album ON Track.AlbumId = Album.AlbumId\n",
 - "JOIN Artist ON Album.ArtistId = Artist.ArtistId\n",
 - "WHERE Artist.Name = 'Elenis Moriset'\n"
 - ]
 - },
 - {
 - "data": {
 - "text/plain": [
 - "''"
 - ]
 - },
 - "execution_count": 15,
 - "metadata": {},
 - "output_type": "execute_result"
 - }
 - ],
 - "source": [
 - "# Without retrieval\n",
 - "query = query_chain.invoke(\n",
 - " {\"question\": \"What are all the genres of elenis moriset songs\", \"proper_nouns\": \"\"}\n",
 - ")\n",
 - "print(query)\n",
 - "db.run(query)"
 - ]
 - },
 - {
 - "cell_type": "code",
 - "execution_count": 16,
 - "id": "e8a3231a-8590-46f5-a954-da06829ee6df",
 - "metadata": {},
 - "outputs": [
 - {
 - "name": "stdout",
 - "output_type": "stream",
 - "text": [
 - "SELECT DISTINCT Genre.Name\n",
 - "FROM Genre\n",
 - "JOIN Track ON Genre.GenreId = Track.GenreId\n",
 - "JOIN Album ON Track.AlbumId = Album.AlbumId\n",
 - "JOIN Artist ON Album.ArtistId = Artist.ArtistId\n",
 - "WHERE Artist.Name = 'Alanis Morissette'\n"
 - ]
 - },
 - {
 - "data": {
 - "text/plain": [
 - "\"[('Rock',)]\""
 - ]
 - },
 - "execution_count": 16,
 - "metadata": {},
 - "output_type": "execute_result"
 - }
 - ],
 - "source": [
 - "# With retrieval\n",
 - "query = chain.invoke({\"question\": \"What are all the genres of elenis moriset songs\"})\n",
 - "print(query)\n",
 - "db.run(query)"
 - ]
 - },
 - {
 - "cell_type": "markdown",
 - "id": "7f99181b-a75c-4ff3-b37b-33f99a506581",
 - "metadata": {},
 - "source": [
 - "We can see that with retrieval we're able to correct the spelling and get back a valid result.\n",
 - "\n",
 - "Another possible approach to this problem is to let an Agent decide for itself when to look up proper nouns. You can see an example of this in the [SQL: Agents](/docs/use_cases/sql/agents) guide." 
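For a rough sense of that agent-based alternative (a sketch, not code from this notebook; the tool name and wiring are illustrative assumptions), the `retriever` built above could be exposed as a tool the agent invokes only when it needs to resolve a proper noun:

```python
from langchain_core.tools import tool

# `retriever` is the FAISS retriever over proper nouns built above.


@tool
def search_proper_nouns(query: str) -> str:
    """Look up artist, album, and genre names that resemble the query."""
    docs = retriever.get_relevant_documents(query)
    return "\n".join(doc.page_content for doc in docs)


# An agent given [search_proper_nouns] can then decide, per question,
# whether a spelling check is needed before it writes any SQL.
```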
- ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "poetry-venv", - "language": "python", - "name": "poetry-venv" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.1" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/docs/docs/use_cases/sql/prompting.ipynb b/docs/docs/use_cases/sql/prompting.ipynb deleted file mode 100644 index 27a60b3de9d7e..0000000000000 --- a/docs/docs/use_cases/sql/prompting.ipynb +++ /dev/null @@ -1,789 +0,0 @@ -{ - "cells": [ - { - "cell_type": "raw", - "metadata": {}, - "source": [ - "---\n", - "sidebar_position: 2\n", - "---" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Prompting strategies\n", - "\n", - "In this guide we'll go over prompting strategies to improve SQL query generation. We'll largely focus on methods for getting relevant database-specific information in your prompt.\n", - "\n", - "## Setup\n", - "\n", - "First, get required packages and set environment variables:" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "%pip install --upgrade --quiet langchain langchain-community langchain-experimental langchain-openai" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We default to OpenAI models in this guide, but you can swap them out for the model provider of your choice." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import getpass\n", - "import os\n", - "\n", - "os.environ[\"OPENAI_API_KEY\"] = getpass.getpass()\n", - "\n", - "# Uncomment the below to use LangSmith. Not required.\n", - "# os.environ[\"LANGCHAIN_API_KEY\"] = getpass.getpass()\n", - "# os.environ[\"LANGCHAIN_TRACING_V2\"] = \"true\"" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The below example will use a SQLite connection with Chinook database. 
Follow [these installation steps](https://database.guide/2-sample-databases-sqlite/) to create `Chinook.db` in the same directory as this notebook:\n",
 - "\n",
 - "* Save [this file](https://raw.githubusercontent.com/lerocha/chinook-database/master/ChinookDatabase/DataSources/Chinook_Sqlite.sql) as `Chinook_Sqlite.sql`\n",
 - "* Run `sqlite3 Chinook.db`\n",
 - "* Run `.read Chinook_Sqlite.sql`\n",
 - "* Test `SELECT * FROM Artist LIMIT 10;`\n",
 - "\n",
 - "Now, `Chinook.db` is in our directory and we can interface with it using the SQLAlchemy-driven `SQLDatabase` class:"
 - ]
 - },
 - {
 - "cell_type": "code",
 - "execution_count": 6,
 - "metadata": {},
 - "outputs": [
 - {
 - "name": "stdout",
 - "output_type": "stream",
 - "text": [
 - "sqlite\n",
 - "['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']\n"
 - ]
 - },
 - {
 - "data": {
 - "text/plain": [
 - "\"[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]\""
 - ]
 - },
 - "execution_count": 6,
 - "metadata": {},
 - "output_type": "execute_result"
 - }
 - ],
 - "source": [
 - "from langchain_community.utilities import SQLDatabase\n",
 - "\n",
 - "db = SQLDatabase.from_uri(\"sqlite:///Chinook.db\", sample_rows_in_table_info=3)\n",
 - "print(db.dialect)\n",
 - "print(db.get_usable_table_names())\n",
 - "db.run(\"SELECT * FROM Artist LIMIT 10;\")"
 - ]
 - },
 - {
 - "cell_type": "markdown",
 - "metadata": {},
 - "source": [
 - "## Dialect-specific prompting\n",
 - "\n",
 - "One of the simplest things we can do is make our prompt specific to the SQL dialect we're using. When using the built-in [create_sql_query_chain](https://api.python.langchain.com/en/latest/chains/langchain.chains.sql_database.query.create_sql_query_chain.html) and [SQLDatabase](https://api.python.langchain.com/en/latest/utilities/langchain_community.utilities.sql_database.SQLDatabase.html), this is handled for you for any of the following dialects:"
 - ]
 - },
 - {
 - "cell_type": "code",
 - "execution_count": 2,
 - "metadata": {},
 - "outputs": [
 - {
 - "data": {
 - "text/plain": [
 - "['crate',\n",
 - " 'duckdb',\n",
 - " 'googlesql',\n",
 - " 'mssql',\n",
 - " 'mysql',\n",
 - " 'mariadb',\n",
 - " 'oracle',\n",
 - " 'postgresql',\n",
 - " 'sqlite',\n",
 - " 'clickhouse',\n",
 - " 'prestodb']"
 - ]
 - },
 - "execution_count": 2,
 - "metadata": {},
 - "output_type": "execute_result"
 - }
 - ],
 - "source": [
 - "from langchain.chains.sql_database.prompt import SQL_PROMPTS\n",
 - "\n",
 - "list(SQL_PROMPTS)"
 - ]
 - },
 - {
 - "cell_type": "markdown",
 - "metadata": {},
 - "source": [
 - "For example, using our current DB we can see that we'll get a SQLite-specific prompt:"
 - ]
 - },
 - {
 - "cell_type": "code",
 - "execution_count": 3,
 - "metadata": {},
 - "outputs": [
 - {
 - "name": "stdout",
 - "output_type": "stream",
 - "text": [
 - "You are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer to the input question.\n",
 - "Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.\n",
 - "Never query for all columns from a table. You must query only the columns that are needed to answer the question. 
Wrap each column name in double quotes (\") to denote them as delimited identifiers.\n", - "Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.\n", - "Pay attention to use date('now') function to get the current date, if the question involves \"today\".\n", - "\n", - "Use the following format:\n", - "\n", - "Question: Question here\n", - "SQLQuery: SQL Query to run\n", - "SQLResult: Result of the SQLQuery\n", - "Answer: Final answer here\n", - "\n", - "Only use the following tables:\n", - "\u001b[33;1m\u001b[1;3m{table_info}\u001b[0m\n", - "\n", - "Question: \u001b[33;1m\u001b[1;3m{input}\u001b[0m\n" - ] - } - ], - "source": [ - "from langchain.chains import create_sql_query_chain\n", - "from langchain_openai import ChatOpenAI\n", - "\n", - "llm = ChatOpenAI(model=\"gpt-3.5-turbo\", temperature=\"0\")\n", - "chain = create_sql_query_chain(llm, db)\n", - "chain.get_prompts()[0].pretty_print()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Table definitions and example rows\n", - "\n", - "In basically any SQL chain, we'll need to feed the model at least part of the database schema. Without this it won't be able to write valid queries. Our database comes with some convenience methods to give us the relevant context. Specifically, we can get the table names, their schemas, and a sample of rows from each table:" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "['table_info', 'table_names']\n", - "\n", - "CREATE TABLE \"Album\" (\n", - "\t\"AlbumId\" INTEGER NOT NULL, \n", - "\t\"Title\" NVARCHAR(160) NOT NULL, \n", - "\t\"ArtistId\" INTEGER NOT NULL, \n", - "\tPRIMARY KEY (\"AlbumId\"), \n", - "\tFOREIGN KEY(\"ArtistId\") REFERENCES \"Artist\" (\"ArtistId\")\n", - ")\n", - "\n", - "/*\n", - "3 rows from Album table:\n", - "AlbumId\tTitle\tArtistId\n", - "1\tFor Those About To Rock We Salute You\t1\n", - "2\tBalls to the Wall\t2\n", - "3\tRestless and Wild\t2\n", - "*/\n", - "\n", - "\n", - "CREATE TABLE \"Artist\" (\n", - "\t\"ArtistId\" INTEGER NOT NULL, \n", - "\t\"Name\" NVARCHAR(120), \n", - "\tPRIMARY KEY (\"ArtistId\")\n", - ")\n", - "\n", - "/*\n", - "3 rows from Artist table:\n", - "ArtistId\tName\n", - "1\tAC/DC\n", - "2\tAccept\n", - "3\tAerosmith\n", - "*/\n", - "\n", - "\n", - "CREATE TABLE \"Customer\" (\n", - "\t\"CustomerId\" INTEGER NOT NULL, \n", - "\t\"FirstName\" NVARCHAR(40) NOT NULL, \n", - "\t\"LastName\" NVARCHAR(20) NOT NULL, \n", - "\t\"Company\" NVARCHAR(80), \n", - "\t\"Address\" NVARCHAR(70), \n", - "\t\"City\" NVARCHAR(40), \n", - "\t\"State\" NVARCHAR(40), \n", - "\t\"Country\" NVARCHAR(40), \n", - "\t\"PostalCode\" NVARCHAR(10), \n", - "\t\"Phone\" NVARCHAR(24), \n", - "\t\"Fax\" NVARCHAR(24), \n", - "\t\"Email\" NVARCHAR(60) NOT NULL, \n", - "\t\"SupportRepId\" INTEGER, \n", - "\tPRIMARY KEY (\"CustomerId\"), \n", - "\tFOREIGN KEY(\"SupportRepId\") REFERENCES \"Employee\" (\"EmployeeId\")\n", - ")\n", - "\n", - "/*\n", - "3 rows from Customer table:\n", - "CustomerId\tFirstName\tLastName\tCompany\tAddress\tCity\tState\tCountry\tPostalCode\tPhone\tFax\tEmail\tSupportRepId\n", - "1\tLuís\tGonçalves\tEmbraer - Empresa Brasileira de Aeronáutica S.A.\tAv. 
Brigadeiro Faria Lima, 2170\tSão José dos Campos\tSP\tBrazil\t12227-000\t+55 (12) 3923-5555\t+55 (12) 3923-5566\tluisg@embraer.com.br\t3\n", - "2\tLeonie\tKöhler\tNone\tTheodor-Heuss-Straße 34\tStuttgart\tNone\tGermany\t70174\t+49 0711 2842222\tNone\tleonekohler@surfeu.de\t5\n", - "3\tFrançois\tTremblay\tNone\t1498 rue Bélanger\tMontréal\tQC\tCanada\tH2G 1A7\t+1 (514) 721-4711\tNone\tftremblay@gmail.com\t3\n", - "*/\n", - "\n", - "\n", - "CREATE TABLE \"Employee\" (\n", - "\t\"EmployeeId\" INTEGER NOT NULL, \n", - "\t\"LastName\" NVARCHAR(20) NOT NULL, \n", - "\t\"FirstName\" NVARCHAR(20) NOT NULL, \n", - "\t\"Title\" NVARCHAR(30), \n", - "\t\"ReportsTo\" INTEGER, \n", - "\t\"BirthDate\" DATETIME, \n", - "\t\"HireDate\" DATETIME, \n", - "\t\"Address\" NVARCHAR(70), \n", - "\t\"City\" NVARCHAR(40), \n", - "\t\"State\" NVARCHAR(40), \n", - "\t\"Country\" NVARCHAR(40), \n", - "\t\"PostalCode\" NVARCHAR(10), \n", - "\t\"Phone\" NVARCHAR(24), \n", - "\t\"Fax\" NVARCHAR(24), \n", - "\t\"Email\" NVARCHAR(60), \n", - "\tPRIMARY KEY (\"EmployeeId\"), \n", - "\tFOREIGN KEY(\"ReportsTo\") REFERENCES \"Employee\" (\"EmployeeId\")\n", - ")\n", - "\n", - "/*\n", - "3 rows from Employee table:\n", - "EmployeeId\tLastName\tFirstName\tTitle\tReportsTo\tBirthDate\tHireDate\tAddress\tCity\tState\tCountry\tPostalCode\tPhone\tFax\tEmail\n", - "1\tAdams\tAndrew\tGeneral Manager\tNone\t1962-02-18 00:00:00\t2002-08-14 00:00:00\t11120 Jasper Ave NW\tEdmonton\tAB\tCanada\tT5K 2N1\t+1 (780) 428-9482\t+1 (780) 428-3457\tandrew@chinookcorp.com\n", - "2\tEdwards\tNancy\tSales Manager\t1\t1958-12-08 00:00:00\t2002-05-01 00:00:00\t825 8 Ave SW\tCalgary\tAB\tCanada\tT2P 2T3\t+1 (403) 262-3443\t+1 (403) 262-3322\tnancy@chinookcorp.com\n", - "3\tPeacock\tJane\tSales Support Agent\t2\t1973-08-29 00:00:00\t2002-04-01 00:00:00\t1111 6 Ave SW\tCalgary\tAB\tCanada\tT2P 5M5\t+1 (403) 262-3443\t+1 (403) 262-6712\tjane@chinookcorp.com\n", - "*/\n", - "\n", - "\n", - "CREATE TABLE \"Genre\" (\n", - "\t\"GenreId\" INTEGER NOT NULL, \n", - "\t\"Name\" NVARCHAR(120), \n", - "\tPRIMARY KEY (\"GenreId\")\n", - ")\n", - "\n", - "/*\n", - "3 rows from Genre table:\n", - "GenreId\tName\n", - "1\tRock\n", - "2\tJazz\n", - "3\tMetal\n", - "*/\n", - "\n", - "\n", - "CREATE TABLE \"Invoice\" (\n", - "\t\"InvoiceId\" INTEGER NOT NULL, \n", - "\t\"CustomerId\" INTEGER NOT NULL, \n", - "\t\"InvoiceDate\" DATETIME NOT NULL, \n", - "\t\"BillingAddress\" NVARCHAR(70), \n", - "\t\"BillingCity\" NVARCHAR(40), \n", - "\t\"BillingState\" NVARCHAR(40), \n", - "\t\"BillingCountry\" NVARCHAR(40), \n", - "\t\"BillingPostalCode\" NVARCHAR(10), \n", - "\t\"Total\" NUMERIC(10, 2) NOT NULL, \n", - "\tPRIMARY KEY (\"InvoiceId\"), \n", - "\tFOREIGN KEY(\"CustomerId\") REFERENCES \"Customer\" (\"CustomerId\")\n", - ")\n", - "\n", - "/*\n", - "3 rows from Invoice table:\n", - "InvoiceId\tCustomerId\tInvoiceDate\tBillingAddress\tBillingCity\tBillingState\tBillingCountry\tBillingPostalCode\tTotal\n", - "1\t2\t2009-01-01 00:00:00\tTheodor-Heuss-Straße 34\tStuttgart\tNone\tGermany\t70174\t1.98\n", - "2\t4\t2009-01-02 00:00:00\tUllevålsveien 14\tOslo\tNone\tNorway\t0171\t3.96\n", - "3\t8\t2009-01-03 00:00:00\tGrétrystraat 63\tBrussels\tNone\tBelgium\t1000\t5.94\n", - "*/\n", - "\n", - "\n", - "CREATE TABLE \"InvoiceLine\" (\n", - "\t\"InvoiceLineId\" INTEGER NOT NULL, \n", - "\t\"InvoiceId\" INTEGER NOT NULL, \n", - "\t\"TrackId\" INTEGER NOT NULL, \n", - "\t\"UnitPrice\" NUMERIC(10, 2) NOT NULL, \n", - "\t\"Quantity\" INTEGER NOT NULL, \n", - "\tPRIMARY KEY 
(\"InvoiceLineId\"), \n", - "\tFOREIGN KEY(\"TrackId\") REFERENCES \"Track\" (\"TrackId\"), \n", - "\tFOREIGN KEY(\"InvoiceId\") REFERENCES \"Invoice\" (\"InvoiceId\")\n", - ")\n", - "\n", - "/*\n", - "3 rows from InvoiceLine table:\n", - "InvoiceLineId\tInvoiceId\tTrackId\tUnitPrice\tQuantity\n", - "1\t1\t2\t0.99\t1\n", - "2\t1\t4\t0.99\t1\n", - "3\t2\t6\t0.99\t1\n", - "*/\n", - "\n", - "\n", - "CREATE TABLE \"MediaType\" (\n", - "\t\"MediaTypeId\" INTEGER NOT NULL, \n", - "\t\"Name\" NVARCHAR(120), \n", - "\tPRIMARY KEY (\"MediaTypeId\")\n", - ")\n", - "\n", - "/*\n", - "3 rows from MediaType table:\n", - "MediaTypeId\tName\n", - "1\tMPEG audio file\n", - "2\tProtected AAC audio file\n", - "3\tProtected MPEG-4 video file\n", - "*/\n", - "\n", - "\n", - "CREATE TABLE \"Playlist\" (\n", - "\t\"PlaylistId\" INTEGER NOT NULL, \n", - "\t\"Name\" NVARCHAR(120), \n", - "\tPRIMARY KEY (\"PlaylistId\")\n", - ")\n", - "\n", - "/*\n", - "3 rows from Playlist table:\n", - "PlaylistId\tName\n", - "1\tMusic\n", - "2\tMovies\n", - "3\tTV Shows\n", - "*/\n", - "\n", - "\n", - "CREATE TABLE \"PlaylistTrack\" (\n", - "\t\"PlaylistId\" INTEGER NOT NULL, \n", - "\t\"TrackId\" INTEGER NOT NULL, \n", - "\tPRIMARY KEY (\"PlaylistId\", \"TrackId\"), \n", - "\tFOREIGN KEY(\"TrackId\") REFERENCES \"Track\" (\"TrackId\"), \n", - "\tFOREIGN KEY(\"PlaylistId\") REFERENCES \"Playlist\" (\"PlaylistId\")\n", - ")\n", - "\n", - "/*\n", - "3 rows from PlaylistTrack table:\n", - "PlaylistId\tTrackId\n", - "1\t3402\n", - "1\t3389\n", - "1\t3390\n", - "*/\n", - "\n", - "\n", - "CREATE TABLE \"Track\" (\n", - "\t\"TrackId\" INTEGER NOT NULL, \n", - "\t\"Name\" NVARCHAR(200) NOT NULL, \n", - "\t\"AlbumId\" INTEGER, \n", - "\t\"MediaTypeId\" INTEGER NOT NULL, \n", - "\t\"GenreId\" INTEGER, \n", - "\t\"Composer\" NVARCHAR(220), \n", - "\t\"Milliseconds\" INTEGER NOT NULL, \n", - "\t\"Bytes\" INTEGER, \n", - "\t\"UnitPrice\" NUMERIC(10, 2) NOT NULL, \n", - "\tPRIMARY KEY (\"TrackId\"), \n", - "\tFOREIGN KEY(\"MediaTypeId\") REFERENCES \"MediaType\" (\"MediaTypeId\"), \n", - "\tFOREIGN KEY(\"GenreId\") REFERENCES \"Genre\" (\"GenreId\"), \n", - "\tFOREIGN KEY(\"AlbumId\") REFERENCES \"Album\" (\"AlbumId\")\n", - ")\n", - "\n", - "/*\n", - "3 rows from Track table:\n", - "TrackId\tName\tAlbumId\tMediaTypeId\tGenreId\tComposer\tMilliseconds\tBytes\tUnitPrice\n", - "1\tFor Those About To Rock (We Salute You)\t1\t1\t1\tAngus Young, Malcolm Young, Brian Johnson\t343719\t11170334\t0.99\n", - "2\tBalls to the Wall\t2\t2\t1\tNone\t342562\t5510424\t0.99\n", - "3\tFast As a Shark\t3\t2\t1\tF. Baltes, S. Kaufman, U. Dirkscneider & W. Hoffman\t230619\t3990994\t0.99\n", - "*/\n" - ] - } - ], - "source": [ - "context = db.get_context()\n", - "print(list(context))\n", - "print(context[\"table_info\"])" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "When we don't have too many, or too wide of, tables, we can just insert the entirety of this information in our prompt:" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "You are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer to the input question.\n", - "Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per SQLite. 
You can order the results to return the most informative data in the database.\n",
 - "Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (\") to denote them as delimited identifiers.\n",
 - "Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.\n",
 - "Pay attention to use date('now') function to get the current date, if the question involves \"today\".\n",
 - "\n",
 - "Use the following format:\n",
 - "\n",
 - "Question: Question here\n",
 - "SQLQuery: SQL Query to run\n",
 - "SQLResult: Result of the SQLQuery\n",
 - "Answer: Final answer here\n",
 - "\n",
 - "Only use the following tables:\n",
 - "\n",
 - "CREATE TABLE \"Album\" (\n",
 - "\t\"AlbumId\" INTEGER NOT NULL, \n",
 - "\t\"Title\" NVARCHAR(160) NOT NULL, \n",
 - "\t\"ArtistId\" INTEGER NOT NULL, \n",
 - "\tPRIMARY KEY (\"AlbumId\"), \n",
 - "\tFOREIGN KEY(\"ArtistId\") REFERENCES \"Artist\" (\"ArtistId\")\n",
 - ")\n",
 - "\n",
 - "/*\n",
 - "3 rows from Album table:\n",
 - "AlbumId\tTitle\tArtistId\n",
 - "1\tFor Those About To Rock We Salute You\t1\n",
 - "2\tBalls to the Wall\t2\n",
 - "3\tRestless and Wild\t2\n",
 - "*/\n",
 - "\n",
 - "\n",
 - "CREATE TABLE \"Artist\" (\n",
 - "\t\"ArtistId\" INTEGER NOT NULL, \n",
 - "\t\"Name\" NVARCHAR(120)\n"
 - ]
 - }
 - ],
 - "source": [
 - "prompt_with_context = chain.get_prompts()[0].partial(table_info=context[\"table_info\"])\n",
 - "print(prompt_with_context.pretty_repr()[:1500])"
 - ]
 - },
 - {
 - "cell_type": "markdown",
 - "metadata": {},
 - "source": [
 - "When we do have database schemas that are too large to fit into our model's context window, we'll need to come up with ways of inserting only the relevant table definitions into the prompt based on the user input. For more on this, head to the [Many tables, wide tables, high-cardinality features](/docs/use_cases/sql/large_db) guide." 
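As a quick illustration of scoping the schema by hand (a sketch; the linked guide covers doing this dynamically), `SQLDatabase.get_table_info` accepts a `table_names` argument, so we can render definitions and sample rows for a subset of tables only:

```python
# Render schema and sample rows for just two tables.
artist_album_info = db.get_table_info(table_names=["Artist", "Album"])

# Partial the prompt with this scoped schema instead of the full table_info.
scoped_prompt = chain.get_prompts()[0].partial(table_info=artist_album_info)
print(scoped_prompt.pretty_repr()[:500])
```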
- ]
 - },
 - {
 - "cell_type": "markdown",
 - "metadata": {},
 - "source": [
 - "## Few-shot examples\n",
 - "\n",
 - "Including examples of natural language questions being converted to valid SQL queries against our database in the prompt will often improve model performance, especially for complex queries.\n",
 - "\n",
 - "Let's say we have the following examples:"
 - ]
 - },
 - {
 - "cell_type": "code",
 - "execution_count": 18,
 - "metadata": {},
 - "outputs": [],
 - "source": [
 - "examples = [\n",
 - " {\"input\": \"List all artists.\", \"query\": \"SELECT * FROM Artist;\"},\n",
 - " {\n",
 - " \"input\": \"Find all albums for the artist 'AC/DC'.\",\n",
 - " \"query\": \"SELECT * FROM Album WHERE ArtistId = (SELECT ArtistId FROM Artist WHERE Name = 'AC/DC');\",\n",
 - " },\n",
 - " {\n",
 - " \"input\": \"List all tracks in the 'Rock' genre.\",\n",
 - " \"query\": \"SELECT * FROM Track WHERE GenreId = (SELECT GenreId FROM Genre WHERE Name = 'Rock');\",\n",
 - " },\n",
 - " {\n",
 - " \"input\": \"Find the total duration of all tracks.\",\n",
 - " \"query\": \"SELECT SUM(Milliseconds) FROM Track;\",\n",
 - " },\n",
 - " {\n",
 - " \"input\": \"List all customers from Canada.\",\n",
 - " \"query\": \"SELECT * FROM Customer WHERE Country = 'Canada';\",\n",
 - " },\n",
 - " {\n",
 - " \"input\": \"How many tracks are there in the album with ID 5?\",\n",
 - " \"query\": \"SELECT COUNT(*) FROM Track WHERE AlbumId = 5;\",\n",
 - " },\n",
 - " {\n",
 - " \"input\": \"Find the total number of invoices.\",\n",
 - " \"query\": \"SELECT COUNT(*) FROM Invoice;\",\n",
 - " },\n",
 - " {\n",
 - " \"input\": \"List all tracks that are longer than 5 minutes.\",\n",
 - " \"query\": \"SELECT * FROM Track WHERE Milliseconds > 300000;\",\n",
 - " },\n",
 - " {\n",
 - " \"input\": \"Who are the top 5 customers by total purchase?\",\n",
 - " \"query\": \"SELECT CustomerId, SUM(Total) AS TotalPurchase FROM Invoice GROUP BY CustomerId ORDER BY TotalPurchase DESC LIMIT 5;\",\n",
 - " },\n",
 - " {\n",
 - " \"input\": \"Which albums are from the year 2000?\",\n",
 - " \"query\": \"SELECT * FROM Album WHERE strftime('%Y', ReleaseDate) = '2000';\",\n",
 - " },\n",
 - " {\n",
 - " \"input\": \"How many employees are there\",\n",
 - " \"query\": 'SELECT COUNT(*) FROM \"Employee\"',\n",
 - " },\n",
 - "]"
 - ]
 - },
 - {
 - "cell_type": "markdown",
 - "metadata": {},
 - "source": [
 - "We can create a few-shot prompt with them like so:"
 - ]
 - },
 - {
 - "cell_type": "code",
 - "execution_count": 43,
 - "metadata": {},
 - "outputs": [],
 - "source": [
 - "from langchain_core.prompts import FewShotPromptTemplate, PromptTemplate\n",
 - "\n",
 - "example_prompt = PromptTemplate.from_template(\"User input: {input}\\nSQL query: {query}\")\n",
 - "prompt = FewShotPromptTemplate(\n",
 - " examples=examples[:5],\n",
 - " example_prompt=example_prompt,\n",
 - " prefix=\"You are a SQLite expert. Given an input question, create a syntactically correct SQLite query to run. Unless otherwise specified, do not return more than {top_k} rows.\\n\\nHere is the relevant table info: {table_info}\\n\\nBelow are a number of examples of questions and their corresponding SQL queries.\",\n",
 - " suffix=\"User input: {input}\\nSQL query: \",\n",
 - " input_variables=[\"input\", \"top_k\", \"table_info\"],\n",
 - ")"
 - ]
 - },
 - {
 - "cell_type": "code",
 - "execution_count": 44,
 - "metadata": {},
 - "outputs": [
 - {
 - "name": "stdout",
 - "output_type": "stream",
 - "text": [
 - "You are a SQLite expert. Given an input question, create a syntactically correct SQLite query to run. 
Unless otherwise specified, do not return more than 3 rows.\n",
 - "\n",
 - "Here is the relevant table info: foo\n",
 - "\n",
 - "Below are a number of examples of questions and their corresponding SQL queries.\n",
 - "\n",
 - "User input: List all artists.\n",
 - "SQL query: SELECT * FROM Artist;\n",
 - "\n",
 - "User input: Find all albums for the artist 'AC/DC'.\n",
 - "SQL query: SELECT * FROM Album WHERE ArtistId = (SELECT ArtistId FROM Artist WHERE Name = 'AC/DC');\n",
 - "\n",
 - "User input: List all tracks in the 'Rock' genre.\n",
 - "SQL query: SELECT * FROM Track WHERE GenreId = (SELECT GenreId FROM Genre WHERE Name = 'Rock');\n",
 - "\n",
 - "User input: Find the total duration of all tracks.\n",
 - "SQL query: SELECT SUM(Milliseconds) FROM Track;\n",
 - "\n",
 - "User input: List all customers from Canada.\n",
 - "SQL query: SELECT * FROM Customer WHERE Country = 'Canada';\n",
 - "\n",
 - "User input: How many artists are there?\n",
 - "SQL query: \n"
 - ]
 - }
 - ],
 - "source": [
 - "print(prompt.format(input=\"How many artists are there?\", top_k=3, table_info=\"foo\"))"
 - ]
 - },
 - {
 - "cell_type": "markdown",
 - "metadata": {},
 - "source": [
 - "## Dynamic few-shot examples\n",
 - "\n",
 - "If we have enough examples, we may want to only include the most relevant ones in the prompt, either because they don't fit in the model's context window or because the long tail of examples distracts the model. And specifically, given any input, we want to include the examples most relevant to that input.\n",
 - "\n",
 - "We can do just this using an ExampleSelector. In this case we'll use a [SemanticSimilarityExampleSelector](https://api.python.langchain.com/en/latest/example_selectors/langchain_core.example_selectors.semantic_similarity.SemanticSimilarityExampleSelector.html), which will store the examples in the vector database of our choosing. 
At runtime it will perform a similarity search between the input and our examples, and return the most semantically similar ones: "
 - ]
 - },
 - {
 - "cell_type": "code",
 - "execution_count": 45,
 - "metadata": {},
 - "outputs": [],
 - "source": [
 - "from langchain_community.vectorstores import FAISS\n",
 - "from langchain_core.example_selectors import SemanticSimilarityExampleSelector\n",
 - "from langchain_openai import OpenAIEmbeddings\n",
 - "\n",
 - "example_selector = SemanticSimilarityExampleSelector.from_examples(\n",
 - " examples,\n",
 - " OpenAIEmbeddings(),\n",
 - " FAISS,\n",
 - " k=5,\n",
 - " input_keys=[\"input\"],\n",
 - ")"
 - ]
 - },
 - {
 - "cell_type": "code",
 - "execution_count": 46,
 - "metadata": {},
 - "outputs": [
 - {
 - "data": {
 - "text/plain": [
 - "[{'input': 'List all artists.', 'query': 'SELECT * FROM Artist;'},\n",
 - " {'input': 'How many employees are there',\n",
 - " 'query': 'SELECT COUNT(*) FROM \"Employee\"'},\n",
 - " {'input': 'How many tracks are there in the album with ID 5?',\n",
 - " 'query': 'SELECT COUNT(*) FROM Track WHERE AlbumId = 5;'},\n",
 - " {'input': 'Which albums are from the year 2000?',\n",
 - " 'query': \"SELECT * FROM Album WHERE strftime('%Y', ReleaseDate) = '2000';\"},\n",
 - " {'input': \"List all tracks in the 'Rock' genre.\",\n",
 - " 'query': \"SELECT * FROM Track WHERE GenreId = (SELECT GenreId FROM Genre WHERE Name = 'Rock');\"}]"
 - ]
 - },
 - "execution_count": 46,
 - "metadata": {},
 - "output_type": "execute_result"
 - }
 - ],
 - "source": [
 - "example_selector.select_examples({\"input\": \"how many artists are there?\"})"
 - ]
 - },
 - {
 - "cell_type": "markdown",
 - "metadata": {},
 - "source": [
 - "To use it, we can pass the ExampleSelector directly into our FewShotPromptTemplate:"
 - ]
 - },
 - {
 - "cell_type": "code",
 - "execution_count": 49,
 - "metadata": {},
 - "outputs": [],
 - "source": [
 - "prompt = FewShotPromptTemplate(\n",
 - " example_selector=example_selector,\n",
 - " example_prompt=example_prompt,\n",
 - " prefix=\"You are a SQLite expert. Given an input question, create a syntactically correct SQLite query to run. Unless otherwise specified, do not return more than {top_k} rows.\\n\\nHere is the relevant table info: {table_info}\\n\\nBelow are a number of examples of questions and their corresponding SQL queries.\",\n",
 - " suffix=\"User input: {input}\\nSQL query: \",\n",
 - " input_variables=[\"input\", \"top_k\", \"table_info\"],\n",
 - ")"
 - ]
 - },
 - {
 - "cell_type": "code",
 - "execution_count": 50,
 - "metadata": {},
 - "outputs": [
 - {
 - "name": "stdout",
 - "output_type": "stream",
 - "text": [
 - "You are a SQLite expert. Given an input question, create a syntactically correct SQLite query to run. 
Unless otherwise specified, do not return more than 3 rows.\n",
 - "\n",
 - "Here is the relevant table info: foo\n",
 - "\n",
 - "Below are a number of examples of questions and their corresponding SQL queries.\n",
 - "\n",
 - "User input: List all artists.\n",
 - "SQL query: SELECT * FROM Artist;\n",
 - "\n",
 - "User input: How many employees are there\n",
 - "SQL query: SELECT COUNT(*) FROM \"Employee\"\n",
 - "\n",
 - "User input: How many tracks are there in the album with ID 5?\n",
 - "SQL query: SELECT COUNT(*) FROM Track WHERE AlbumId = 5;\n",
 - "\n",
 - "User input: Which albums are from the year 2000?\n",
 - "SQL query: SELECT * FROM Album WHERE strftime('%Y', ReleaseDate) = '2000';\n",
 - "\n",
 - "User input: List all tracks in the 'Rock' genre.\n",
 - "SQL query: SELECT * FROM Track WHERE GenreId = (SELECT GenreId FROM Genre WHERE Name = 'Rock');\n",
 - "\n",
 - "User input: how many artists are there?\n",
 - "SQL query: \n"
 - ]
 - }
 - ],
 - "source": [
 - "print(prompt.format(input=\"how many artists are there?\", top_k=3, table_info=\"foo\"))"
 - ]
 - },
 - {
 - "cell_type": "code",
 - "execution_count": 52,
 - "metadata": {},
 - "outputs": [
 - {
 - "data": {
 - "text/plain": [
 - "'SELECT COUNT(*) FROM Artist;'"
 - ]
 - },
 - "execution_count": 52,
 - "metadata": {},
 - "output_type": "execute_result"
 - }
 - ],
 - "source": [
 - "chain = create_sql_query_chain(llm, db, prompt)\n",
 - "chain.invoke({\"question\": \"how many artists are there?\"})"
 - ]
 - }
 - ],
 - "metadata": {
 - "kernelspec": {
 - "display_name": "poetry-venv",
 - "language": "python",
 - "name": "poetry-venv"
 - },
 - "language_info": {
 - "codemirror_mode": {
 - "name": "ipython",
 - "version": 3
 - },
 - "file_extension": ".py",
 - "mimetype": "text/x-python",
 - "name": "python",
 - "nbconvert_exporter": "python",
 - "pygments_lexer": "ipython3",
 - "version": "3.9.1"
 - }
 - },
 - "nbformat": 4,
 - "nbformat_minor": 4
-}
diff --git a/docs/docs/use_cases/sql/query_checking.ipynb b/docs/docs/use_cases/sql/query_checking.ipynb
deleted file mode 100644
index fcb5d44a2ba30..0000000000000
--- a/docs/docs/use_cases/sql/query_checking.ipynb
+++ /dev/null
@@ -1,389 +0,0 @@
-{
 - "cells": [
 - {
 - "cell_type": "raw",
 - "id": "494149c1-9a1a-4b75-8982-6bb19cc5e14e",
 - "metadata": {},
 - "source": [
 - "---\n",
 - "sidebar_position: 3\n",
 - "---"
 - ]
 - },
 - {
 - "cell_type": "markdown",
 - "id": "4da7ae91-4973-4e97-a570-fa24024ec65d",
 - "metadata": {},
 - "source": [
 - "# Query validation\n",
 - "\n",
 - "Perhaps the most error-prone part of any SQL chain or agent is writing valid and safe SQL queries. In this guide we'll go over some strategies for validating our queries and handling invalid queries.\n",
 - "\n",
 - "## Setup\n",
 - "\n",
 - "First, get required packages and set environment variables:"
 - ]
 - },
 - {
 - "cell_type": "code",
 - "execution_count": null,
 - "id": "5d40d5bc-3647-4b5d-808a-db470d40fe7a",
 - "metadata": {},
 - "outputs": [],
 - "source": [
 - "%pip install --upgrade --quiet langchain langchain-community langchain-openai"
 - ]
 - },
 - {
 - "cell_type": "markdown",
 - "id": "c998536a-b1ff-46e7-ac51-dc6deb55d22b",
 - "metadata": {},
 - "source": [
 - "We default to OpenAI models in this guide, but you can swap them out for the model provider of your choice."
 - ]
 - },
 - {
 - "cell_type": "code",
 - "execution_count": null,
 - "id": "71f46270-e1c6-45b4-b36e-ea2e9f860eba",
 - "metadata": {},
 - "outputs": [],
 - "source": [
 - "import getpass\n",
 - "import os\n",
 - "\n",
 - "os.environ[\"OPENAI_API_KEY\"] = getpass.getpass()\n",
 - "\n",
 - "# Uncomment the below to use LangSmith. 
Not required.\n",
 - "# os.environ[\"LANGCHAIN_API_KEY\"] = getpass.getpass()\n",
 - "# os.environ[\"LANGCHAIN_TRACING_V2\"] = \"true\""
 - ]
 - },
 - {
 - "cell_type": "markdown",
 - "id": "a0a2151b-cecf-4559-92a1-ca48824fed18",
 - "metadata": {},
 - "source": [
 - "The below example will use a SQLite connection with Chinook database. Follow [these installation steps](https://database.guide/2-sample-databases-sqlite/) to create `Chinook.db` in the same directory as this notebook:\n",
 - "\n",
 - "* Save [this file](https://raw.githubusercontent.com/lerocha/chinook-database/master/ChinookDatabase/DataSources/Chinook_Sqlite.sql) as `Chinook_Sqlite.sql`\n",
 - "* Run `sqlite3 Chinook.db`\n",
 - "* Run `.read Chinook_Sqlite.sql`\n",
 - "* Test `SELECT * FROM Artist LIMIT 10;`\n",
 - "\n",
 - "Now, `Chinook.db` is in our directory and we can interface with it using the SQLAlchemy-driven `SQLDatabase` class:"
 - ]
 - },
 - {
 - "cell_type": "code",
 - "execution_count": 4,
 - "id": "8cedc936-5268-4bfa-b838-bdcc1ee9573c",
 - "metadata": {},
 - "outputs": [
 - {
 - "name": "stdout",
 - "output_type": "stream",
 - "text": [
 - "sqlite\n",
 - "['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']\n"
 - ]
 - },
 - {
 - "data": {
 - "text/plain": [
 - "\"[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]\""
 - ]
 - },
 - "execution_count": 4,
 - "metadata": {},
 - "output_type": "execute_result"
 - }
 - ],
 - "source": [
 - "from langchain_community.utilities import SQLDatabase\n",
 - "\n",
 - "db = SQLDatabase.from_uri(\"sqlite:///Chinook.db\")\n",
 - "print(db.dialect)\n",
 - "print(db.get_usable_table_names())\n",
 - "db.run(\"SELECT * FROM Artist LIMIT 10;\")"
 - ]
 - },
 - {
 - "cell_type": "markdown",
 - "id": "2d203315-fab7-4621-80da-41e9bf82d803",
 - "metadata": {},
 - "source": [
 - "## Query checker\n",
 - "\n",
 - "Perhaps the simplest strategy is to ask the model itself to check the original query for common mistakes. Suppose we have the following SQL query chain:"
 - ]
 - },
 - {
 - "cell_type": "code",
 - "execution_count": null,
 - "id": "ec66bb76-b1ad-48ad-a7d4-b518e9421b86",
 - "metadata": {},
 - "outputs": [],
 - "source": [
 - "from langchain.chains import create_sql_query_chain\n",
 - "from langchain_openai import ChatOpenAI\n",
 - "\n",
 - "llm = ChatOpenAI(model=\"gpt-3.5-turbo\", temperature=0)\n",
 - "chain = create_sql_query_chain(llm, db)"
 - ]
 - },
 - {
 - "cell_type": "markdown",
 - "id": "da01023d-cc05-43e3-a38d-ed9d56d3ad15",
 - "metadata": {},
 - "source": [
 - "And we want to validate its outputs. 
We can do so by extending the chain with a second prompt and model call:"
 - ]
 - },
 - {
 - "cell_type": "code",
 - "execution_count": 25,
 - "id": "16686750-d8ee-4c60-8d67-b28281cb6164",
 - "metadata": {},
 - "outputs": [],
 - "source": [
 - "from langchain_core.output_parsers import StrOutputParser\n",
 - "from langchain_core.prompts import ChatPromptTemplate\n",
 - "\n",
 - "system = \"\"\"Double check the user's {dialect} query for common mistakes, including:\n",
 - "- Using NOT IN with NULL values\n",
 - "- Using UNION when UNION ALL should have been used\n",
 - "- Using BETWEEN for exclusive ranges\n",
 - "- Data type mismatch in predicates\n",
 - "- Properly quoting identifiers\n",
 - "- Using the correct number of arguments for functions\n",
 - "- Casting to the correct data type\n",
 - "- Using the proper columns for joins\n",
 - "\n",
 - "If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.\n",
 - "\n",
 - "Output the final SQL query only.\"\"\"\n",
 - "prompt = ChatPromptTemplate.from_messages(\n",
 - " [(\"system\", system), (\"human\", \"{query}\")]\n",
 - ").partial(dialect=db.dialect)\n",
 - "validation_chain = prompt | llm | StrOutputParser()\n",
 - "\n",
 - "full_chain = {\"query\": chain} | validation_chain"
 - ]
 - },
 - {
 - "cell_type": "code",
 - "execution_count": 37,
 - "id": "3a910260-205d-4f4e-afc6-9477572dc947",
 - "metadata": {},
 - "outputs": [
 - {
 - "data": {
 - "text/plain": [
 - "\"SELECT AVG(Invoice.Total) AS AverageInvoice\\nFROM Invoice\\nJOIN Customer ON Invoice.CustomerId = Customer.CustomerId\\nWHERE Customer.Country = 'USA'\\nAND Customer.Fax IS NULL\\nAND Invoice.InvoiceDate >= '2003-01-01'\\nAND Invoice.InvoiceDate < '2010-01-01'\""
 - ]
 - },
 - "execution_count": 37,
 - "metadata": {},
 - "output_type": "execute_result"
 - }
 - ],
 - "source": [
 - "query = full_chain.invoke(\n",
 - " {\n",
 - " \"question\": \"What's the average Invoice from an American customer whose Fax is missing since 2003 but before 2010\"\n",
 - " }\n",
 - ")\n",
 - "query"
 - ]
 - },
 - {
 - "cell_type": "code",
 - "execution_count": 38,
 - "id": "d01d78b5-89a0-4c12-b743-707ebe64ba86",
 - "metadata": {},
 - "outputs": [
 - {
 - "data": {
 - "text/plain": [
 - "'[(6.632999999999998,)]'"
 - ]
 - },
 - "execution_count": 38,
 - "metadata": {},
 - "output_type": "execute_result"
 - }
 - ],
 - "source": [
 - "db.run(query)"
 - ]
 - },
 - {
 - "cell_type": "markdown",
 - "id": "6e133526-26bd-49da-9cfa-7adc0e59fd72",
 - "metadata": {},
 - "source": [
 - "The obvious downside of this approach is that we need to make two model calls instead of one to generate our query. To get around this we can try to perform the query generation and query check in a single model invocation:"
 - ]
 - },
 - {
 - "cell_type": "code",
 - "execution_count": 45,
 - "id": "7af0030a-549e-4e69-9298-3d0a038c2fdd",
 - "metadata": {},
 - "outputs": [
 - {
 - "name": "stdout",
 - "output_type": "stream",
 - "text": [
 - "================================\u001b[1m System Message \u001b[0m================================\n",
 - "\n",
 - "You are a \u001b[33;1m\u001b[1;3m{dialect}\u001b[0m expert. Given an input question, create a syntactically correct \u001b[33;1m\u001b[1;3m{dialect}\u001b[0m query to run.\n",
 - "Unless the user specifies in the question a specific number of examples to obtain, query for at most \u001b[33;1m\u001b[1;3m{top_k}\u001b[0m results using the LIMIT clause as per \u001b[33;1m\u001b[1;3m{dialect}\u001b[0m. 
You can order the results to return the most informative data in the database.\n",
 - "Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (\") to denote them as delimited identifiers.\n",
 - "Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.\n",
 - "Pay attention to use date('now') function to get the current date, if the question involves \"today\".\n",
 - "\n",
 - "Only use the following tables:\n",
 - "\u001b[33;1m\u001b[1;3m{table_info}\u001b[0m\n",
 - "\n",
 - "Write an initial draft of the query. Then double check the \u001b[33;1m\u001b[1;3m{dialect}\u001b[0m query for common mistakes, including:\n",
 - "- Using NOT IN with NULL values\n",
 - "- Using UNION when UNION ALL should have been used\n",
 - "- Using BETWEEN for exclusive ranges\n",
 - "- Data type mismatch in predicates\n",
 - "- Properly quoting identifiers\n",
 - "- Using the correct number of arguments for functions\n",
 - "- Casting to the correct data type\n",
 - "- Using the proper columns for joins\n",
 - "\n",
 - "Use format:\n",
 - "\n",
 - "First draft: <>\n",
 - "Final answer: <>\n",
 - "\n",
 - "\n",
 - "================================\u001b[1m Human Message \u001b[0m=================================\n",
 - "\n",
 - "\u001b[33;1m\u001b[1;3m{input}\u001b[0m\n"
 - ]
 - }
 - ],
 - "source": [
 - "system = \"\"\"You are a {dialect} expert. Given an input question, create a syntactically correct {dialect} query to run.\n",
 - "Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per {dialect}. You can order the results to return the most informative data in the database.\n",
 - "Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (\") to denote them as delimited identifiers.\n",
 - "Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.\n",
 - "Pay attention to use date('now') function to get the current date, if the question involves \"today\".\n",
 - "\n",
 - "Only use the following tables:\n",
 - "{table_info}\n",
 - "\n",
 - "Write an initial draft of the query. 
Then double check the {dialect} query for common mistakes, including:\n", - "- Using NOT IN with NULL values\n", - "- Using UNION when UNION ALL should have been used\n", - "- Using BETWEEN for exclusive ranges\n", - "- Data type mismatch in predicates\n", - "- Properly quoting identifiers\n", - "- Using the correct number of arguments for functions\n", - "- Casting to the correct data type\n", - "- Using the proper columns for joins\n", - "\n", - "Use format:\n", - "\n", - "First draft: <>\n", - "Final answer: <>\n", - "\"\"\"\n", - "prompt = ChatPromptTemplate.from_messages([(\"system\", system), (\"human\", \"{input}\")]).partial(dialect=db.dialect)\n", - "\n", - "def parse_final_answer(output: str) -> str:\n", - " return output.split(\"Final answer: \")[1]\n", - " \n", - "chain = create_sql_query_chain(llm, db, prompt=prompt) | parse_final_answer\n", - "prompt.pretty_print()" - ] - }, - { - "cell_type": "code", - "execution_count": 46, - "id": "806e27a2-e511-45ea-a4ed-8ce8fa6e1d58", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "\"\\nSELECT AVG(i.Total) AS AverageInvoice\\nFROM Invoice i\\nJOIN Customer c ON i.CustomerId = c.CustomerId\\nWHERE c.Country = 'USA' AND c.Fax IS NULL AND i.InvoiceDate >= date('2003-01-01') AND i.InvoiceDate < date('2010-01-01')\"" - ] - }, - "execution_count": 46, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "query = chain.invoke(\n", - " {\n", - " \"question\": \"What's the average Invoice from an American customer whose Fax is missing since 2003 but before 2010\"\n", - " }\n", - ")\n", - "query" - ] - }, - { - "cell_type": "code", - "execution_count": 47, - "id": "70fff2fa-1f86-4f83-9fd2-e87a5234d329", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'[(6.632999999999998,)]'" - ] - }, - "execution_count": 47, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "db.run(query)" - ] - }, - { - "cell_type": "markdown", - "id": "fc8af115-7c23-421a-8fd7-29bf1b6687a4", - "metadata": {}, - "source": [ - "## Human-in-the-loop\n", - "\n", - "In some cases our data is sensitive enough that we never want to execute a SQL query without a human approving it first. Head to the [Tool use: Human-in-the-loop](/docs/use_cases/tool_use/human_in_the_loop) page to learn how to add a human-in-the-loop to any tool, chain or agent.\n", - "\n", - "## Error handling\n", - "\n", - "At some point, the model will make a mistake and craft an invalid SQL query. Or an issue will arise with our database. Or the model API will go down. We'll want to add some error handling behavior to our chains and agents so that we fail gracefully in these situations, and perhaps even automatically recover. To learn about error handling with tools, head to the [Tool use: Error handling](/docs/use_cases/tool_use/tool_error_handling) page." 
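As a minimal sketch of that fail-gracefully idea (not from this notebook; the helper below is a hypothetical wrapper around the `chain` and `db` defined above), we can catch both generation and execution failures in one place:

```python
def run_query_safely(question: str) -> str:
    """Generate a query with the chain and execute it, degrading gracefully."""
    try:
        query = chain.invoke({"question": question})
        return db.run(query)
    except Exception as exc:  # invalid SQL, database errors, model API outages
        return f"Sorry, the query could not be completed: {exc}"


run_query_safely("How many invoices are there?")
```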
- ]
 - }
 - ],
 - "metadata": {
 - "kernelspec": {
 - "display_name": "Python 3 (ipykernel)",
 - "language": "python",
 - "name": "python3"
 - },
 - "language_info": {
 - "codemirror_mode": {
 - "name": "ipython",
 - "version": 3
 - },
 - "file_extension": ".py",
 - "mimetype": "text/x-python",
 - "name": "python",
 - "nbconvert_exporter": "python",
 - "pygments_lexer": "ipython3",
 - "version": "3.9.1"
 - }
 - },
 - "nbformat": 4,
 - "nbformat_minor": 5
-}
diff --git a/docs/docs/use_cases/sql/quickstart.ipynb b/docs/docs/use_cases/sql/quickstart.ipynb
deleted file mode 100644
index 490700a45197b..0000000000000
--- a/docs/docs/use_cases/sql/quickstart.ipynb
+++ /dev/null
@@ -1,603 +0,0 @@
-{
 - "cells": [
 - {
 - "cell_type": "raw",
 - "metadata": {},
 - "source": [
 - "---\n",
 - "sidebar_position: 0\n",
 - "---"
 - ]
 - },
 - {
 - "cell_type": "markdown",
 - "metadata": {},
 - "source": [
 - "# Quickstart\n",
 - "\n",
 - "In this guide we'll go over the basic ways to create a Q&A chain and agent over a SQL database. These systems will allow us to ask a question about the data in a SQL database and get back a natural language answer. The main difference between the two is that our agent can query the database in a loop as many times as it needs to answer the question.\n",
 - "\n",
 - "## ⚠️ Security note ⚠️\n",
 - "\n",
 - "Building Q&A systems over SQL databases requires executing model-generated SQL queries. There are inherent risks in doing this. Make sure that your database connection permissions are always scoped as narrowly as possible for your chain/agent's needs. This will mitigate, though not eliminate, the risks of building a model-driven system. For more on general security best practices, [see here](/docs/security).\n",
 - "\n",
 - "\n",
 - "## Architecture\n",
 - "\n",
 - "At a high level, the steps of any SQL chain and agent are:\n",
 - "\n",
 - "1. **Convert question to SQL query**: Model converts user input to a SQL query.\n",
 - "2. **Execute SQL query**: Execute the SQL query.\n",
 - "3. **Answer the question**: Model responds to user input using the query results.\n",
 - "\n",
 - "\n",
 - "![sql_usecase.png](../../../static/img/sql_usecase.png)\n",
 - "\n",
 - "## Setup\n",
 - "\n",
 - "First, get required packages and set environment variables:"
 - ]
 - },
 - {
 - "cell_type": "code",
 - "execution_count": 2,
 - "metadata": {},
 - "outputs": [],
 - "source": [
 - "%pip install --upgrade --quiet langchain langchain-community langchain-openai"
 - ]
 - },
 - {
 - "cell_type": "markdown",
 - "metadata": {},
 - "source": [
 - "We default to OpenAI models in this guide."
 - ]
 - },
 - {
 - "cell_type": "code",
 - "execution_count": null,
 - "metadata": {},
 - "outputs": [],
 - "source": [
 - "import getpass\n",
 - "import os\n",
 - "\n",
 - "os.environ[\"OPENAI_API_KEY\"] = getpass.getpass()\n",
 - "\n",
 - "# Uncomment the below to use LangSmith. Not required.\n",
 - "# os.environ[\"LANGCHAIN_API_KEY\"] = getpass.getpass()\n",
 - "# os.environ[\"LANGCHAIN_TRACING_V2\"] = \"true\""
 - ]
 - },
 - {
 - "cell_type": "markdown",
 - "metadata": {},
 - "source": [
 - "The below example will use a SQLite connection with Chinook database. 
Follow [these installation steps](https://database.guide/2-sample-databases-sqlite/) to create `Chinook.db` in the same directory as this notebook:\n", - "\n", - "* Save [this file](https://raw.githubusercontent.com/lerocha/chinook-database/master/ChinookDatabase/DataSources/Chinook_Sqlite.sql) as `Chinook_Sqlite.sql`\n", - "* Run `sqlite3 Chinook.db`\n", - "* Run `.read Chinook_Sqlite.sql`\n", - "* Test `SELECT * FROM Artist LIMIT 10;`\n", - "\n", - "Now, `Chinook.db` is in our directory and we can interface with it using the SQLAlchemy-driven `SQLDatabase` class:" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "sqlite\n", - "['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']\n" - ] - }, - { - "data": { - "text/plain": [ - "\"[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]\"" - ] - }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from langchain_community.utilities import SQLDatabase\n", - "\n", - "db = SQLDatabase.from_uri(\"sqlite:///Chinook.db\")\n", - "print(db.dialect)\n", - "print(db.get_usable_table_names())\n", - "db.run(\"SELECT * FROM Artist LIMIT 10;\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Great! We've got a SQL database that we can query. Now let's try hooking it up to an LLM.\n", - "\n", - "## Chain\n", - "\n", - "Let's create a simple chain that takes a question, turns it into a SQL query, executes the query, and uses the result to answer the original question.\n", - "\n", - "### Convert question to SQL query\n", - "\n", - "The first step in a SQL chain or agent is to take the user input and convert it to a SQL query. LangChain comes with a built-in chain for this: [create_sql_query_chain](https://api.python.langchain.com/en/latest/chains/langchain.chains.sql_database.query.create_sql_query_chain.html)" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'SELECT COUNT(*) FROM Employee'" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from langchain.chains import create_sql_query_chain\n", - "from langchain_openai import ChatOpenAI\n", - "\n", - "llm = ChatOpenAI(model=\"gpt-3.5-turbo\", temperature=0)\n", - "chain = create_sql_query_chain(llm, db)\n", - "response = chain.invoke({\"question\": \"How many employees are there\"})\n", - "response" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We can execute the query to make sure it's valid:" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'[(8,)]'" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "db.run(response)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We can look at the [LangSmith trace](https://smith.langchain.com/public/c8fa52ea-be46-4829-bde2-52894970b830/r) to get a better understanding of what this chain is doing. We can also inspect the chain directly for its prompts. 
Looking at the prompt (below), we can see that it is:\n", - "\n", - "* Dialect-specific. In this case it references SQLite explicitly.\n", - "* Has definitions for all the available tables.\n", - "* Has three example rows for each table.\n", - "\n", - "This technique is inspired by papers like [this](https://arxiv.org/pdf/2204.00498.pdf), which suggest that showing example rows and being explicit about tables improves performance." - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "You are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer to the input question.\n", - "Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.\n", - "Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (\") to denote them as delimited identifiers.\n", - "Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.\n", - "Pay attention to use date('now') function to get the current date, if the question involves \"today\".\n", - "\n", - "Use the following format:\n", - "\n", - "Question: Question here\n", - "SQLQuery: SQL Query to run\n", - "SQLResult: Result of the SQLQuery\n", - "Answer: Final answer here\n", - "\n", - "Only use the following tables:\n", - "\u001b[33;1m\u001b[1;3m{table_info}\u001b[0m\n", - "\n", - "Question: \u001b[33;1m\u001b[1;3m{input}\u001b[0m\n" - ] - } - ], - "source": [ - "chain.get_prompts()[0].pretty_print()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Execute SQL query\n", - "\n", - "Now that we've generated a SQL query, we'll want to execute it. **This is the most dangerous part of creating a SQL chain.** Consider carefully if it is OK to run automated queries over your data. Minimize the database connection permissions as much as possible. Consider adding a human approval step to your chains before query execution (see below).\n", - "\n", - "We can use the `QuerySQLDataBaseTool` to easily add query execution to our chain:" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'[(8,)]'" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool\n", - "\n", - "execute_query = QuerySQLDataBaseTool(db=db)\n", - "write_query = create_sql_query_chain(llm, db)\n", - "chain = write_query | execute_query\n", - "chain.invoke({\"question\": \"How many employees are there\"})" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Answer the question\n", - "\n", - "Now that we've got a way to automatically generate and execute queries, we just need to combine the original question and SQL query result to generate a final answer. 
We can do this by passing the question and result to the LLM once more:" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'There are 8 employees.'" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from operator import itemgetter\n", - "\n", - "from langchain_core.output_parsers import StrOutputParser\n", - "from langchain_core.prompts import PromptTemplate\n", - "from langchain_core.runnables import RunnablePassthrough\n", - "\n", - "answer_prompt = PromptTemplate.from_template(\n", - " \"\"\"Given the following user question, corresponding SQL query, and SQL result, answer the user question.\n", - "\n", - "Question: {question}\n", - "SQL Query: {query}\n", - "SQL Result: {result}\n", - "Answer: \"\"\"\n", - ")\n", - "\n", - "answer = answer_prompt | llm | StrOutputParser()\n", - "chain = (\n", - " RunnablePassthrough.assign(query=write_query).assign(\n", - " result=itemgetter(\"query\") | execute_query\n", - " )\n", - " | answer\n", - ")\n", - "\n", - "chain.invoke({\"question\": \"How many employees are there\"})" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Next steps\n", - "\n", - "For more complex query generation, we may want to create few-shot prompts or add query-checking steps. For advanced techniques like this and more, check out:\n", - "\n", - "* [Prompting strategies](/docs/use_cases/sql/prompting): Advanced prompt engineering techniques.\n", - "* [Query checking](/docs/use_cases/sql/query_checking): Add query validation and error handling.\n", - "* [Large databases](/docs/use_cases/sql/large_db): Techniques for working with large databases." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Agents\n", - "\n", - "LangChain has an SQL Agent which provides a more flexible way of interacting with SQL databases. The main advantages of using the SQL Agent are:\n", - "\n", - "- It can answer questions based on the database's schema as well as on the database's content (like describing a specific table).\n", - "- It can recover from errors by running a generated query, catching the traceback and regenerating it correctly.\n", - "- It can answer questions that require multiple dependent queries.\n", - "\n", - "To initialize the agent, we use the `create_sql_agent` function. This agent uses the `SQLDatabaseToolkit`, which contains tools to: \n", - "\n", - "* Create and execute queries\n", - "* Check query syntax\n", - "* Retrieve table descriptions\n", - "* ... 
and more\n", - "\n", - "### Initializing agent" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "from langchain_community.agent_toolkits import create_sql_agent\n", - "\n", - "agent_executor = create_sql_agent(llm, db=db, agent_type=\"openai-tools\", verbose=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "\n", - "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", - "\u001b[32;1m\u001b[1;3m\n", - "Invoking: `sql_db_list_tables` with `{}`\n", - "\n", - "\n", - "\u001b[0m\u001b[38;5;200m\u001b[1;3mAlbum, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track\u001b[0m\u001b[32;1m\u001b[1;3m\n", - "Invoking: `sql_db_schema` with `Invoice,Customer`\n", - "\n", - "\n", - "\u001b[0m\u001b[33;1m\u001b[1;3m\n", - "CREATE TABLE \"Customer\" (\n", - "\t\"CustomerId\" INTEGER NOT NULL, \n", - "\t\"FirstName\" NVARCHAR(40) NOT NULL, \n", - "\t\"LastName\" NVARCHAR(20) NOT NULL, \n", - "\t\"Company\" NVARCHAR(80), \n", - "\t\"Address\" NVARCHAR(70), \n", - "\t\"City\" NVARCHAR(40), \n", - "\t\"State\" NVARCHAR(40), \n", - "\t\"Country\" NVARCHAR(40), \n", - "\t\"PostalCode\" NVARCHAR(10), \n", - "\t\"Phone\" NVARCHAR(24), \n", - "\t\"Fax\" NVARCHAR(24), \n", - "\t\"Email\" NVARCHAR(60) NOT NULL, \n", - "\t\"SupportRepId\" INTEGER, \n", - "\tPRIMARY KEY (\"CustomerId\"), \n", - "\tFOREIGN KEY(\"SupportRepId\") REFERENCES \"Employee\" (\"EmployeeId\")\n", - ")\n", - "\n", - "/*\n", - "3 rows from Customer table:\n", - "CustomerId\tFirstName\tLastName\tCompany\tAddress\tCity\tState\tCountry\tPostalCode\tPhone\tFax\tEmail\tSupportRepId\n", - "1\tLuís\tGonçalves\tEmbraer - Empresa Brasileira de Aeronáutica S.A.\tAv. 
Brigadeiro Faria Lima, 2170\tSão José dos Campos\tSP\tBrazil\t12227-000\t+55 (12) 3923-5555\t+55 (12) 3923-5566\tluisg@embraer.com.br\t3\n", - "2\tLeonie\tKöhler\tNone\tTheodor-Heuss-Straße 34\tStuttgart\tNone\tGermany\t70174\t+49 0711 2842222\tNone\tleonekohler@surfeu.de\t5\n", - "3\tFrançois\tTremblay\tNone\t1498 rue Bélanger\tMontréal\tQC\tCanada\tH2G 1A7\t+1 (514) 721-4711\tNone\tftremblay@gmail.com\t3\n", - "*/\n", - "\n", - "\n", - "CREATE TABLE \"Invoice\" (\n", - "\t\"InvoiceId\" INTEGER NOT NULL, \n", - "\t\"CustomerId\" INTEGER NOT NULL, \n", - "\t\"InvoiceDate\" DATETIME NOT NULL, \n", - "\t\"BillingAddress\" NVARCHAR(70), \n", - "\t\"BillingCity\" NVARCHAR(40), \n", - "\t\"BillingState\" NVARCHAR(40), \n", - "\t\"BillingCountry\" NVARCHAR(40), \n", - "\t\"BillingPostalCode\" NVARCHAR(10), \n", - "\t\"Total\" NUMERIC(10, 2) NOT NULL, \n", - "\tPRIMARY KEY (\"InvoiceId\"), \n", - "\tFOREIGN KEY(\"CustomerId\") REFERENCES \"Customer\" (\"CustomerId\")\n", - ")\n", - "\n", - "/*\n", - "3 rows from Invoice table:\n", - "InvoiceId\tCustomerId\tInvoiceDate\tBillingAddress\tBillingCity\tBillingState\tBillingCountry\tBillingPostalCode\tTotal\n", - "1\t2\t2009-01-01 00:00:00\tTheodor-Heuss-Straße 34\tStuttgart\tNone\tGermany\t70174\t1.98\n", - "2\t4\t2009-01-02 00:00:00\tUllevålsveien 14\tOslo\tNone\tNorway\t0171\t3.96\n", - "3\t8\t2009-01-03 00:00:00\tGrétrystraat 63\tBrussels\tNone\tBelgium\t1000\t5.94\n", - "*/\u001b[0m\u001b[32;1m\u001b[1;3m\n", - "Invoking: `sql_db_query` with `SELECT c.Country, SUM(i.Total) AS TotalSales FROM Invoice i JOIN Customer c ON i.CustomerId = c.CustomerId GROUP BY c.Country ORDER BY TotalSales DESC LIMIT 10;`\n", - "responded: To list the total sales per country, I can query the \"Invoice\" and \"Customer\" tables. I will join these tables on the \"CustomerId\" column and group the results by the \"BillingCountry\" column. Then, I will calculate the sum of the \"Total\" column to get the total sales per country. Finally, I will order the results in descending order of the total sales.\n", - "\n", - "Here is the SQL query:\n", - "\n", - "```sql\n", - "SELECT c.Country, SUM(i.Total) AS TotalSales\n", - "FROM Invoice i\n", - "JOIN Customer c ON i.CustomerId = c.CustomerId\n", - "GROUP BY c.Country\n", - "ORDER BY TotalSales DESC\n", - "LIMIT 10;\n", - "```\n", - "\n", - "Now, I will execute this query to get the total sales per country.\n", - "\n", - "\u001b[0m\u001b[36;1m\u001b[1;3m[('USA', 523.0600000000003), ('Canada', 303.9599999999999), ('France', 195.09999999999994), ('Brazil', 190.09999999999997), ('Germany', 156.48), ('United Kingdom', 112.85999999999999), ('Czech Republic', 90.24000000000001), ('Portugal', 77.23999999999998), ('India', 75.25999999999999), ('Chile', 46.62)]\u001b[0m\u001b[32;1m\u001b[1;3mThe total sales per country are as follows:\n", - "\n", - "1. USA: $523.06\n", - "2. Canada: $303.96\n", - "3. France: $195.10\n", - "4. Brazil: $190.10\n", - "5. Germany: $156.48\n", - "6. United Kingdom: $112.86\n", - "7. Czech Republic: $90.24\n", - "8. Portugal: $77.24\n", - "9. India: $75.26\n", - "10. Chile: $46.62\n", - "\n", - "To answer the second question, the country whose customers spent the most is the USA, with a total sales of $523.06.\u001b[0m\n", - "\n", - "\u001b[1m> Finished chain.\u001b[0m\n" - ] - }, - { - "data": { - "text/plain": [ - "{'input': \"List the total sales per country. Which country's customers spent the most?\",\n", - " 'output': 'The total sales per country are as follows:\\n\\n1. USA: $523.06\\n2. 
Canada: $303.96\\n3. France: $195.10\\n4. Brazil: $190.10\\n5. Germany: $156.48\\n6. United Kingdom: $112.86\\n7. Czech Republic: $90.24\\n8. Portugal: $77.24\\n9. India: $75.26\\n10. Chile: $46.62\\n\\nTo answer the second question, the country whose customers spent the most is the USA, with a total sales of $523.06.'}" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "agent_executor.invoke(\n", - " {\n", - " \"input\": \"List the total sales per country. Which country's customers spent the most?\"\n", - " }\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "\n", - "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", - "\u001b[32;1m\u001b[1;3m\n", - "Invoking: `sql_db_list_tables` with `{}`\n", - "\n", - "\n", - "\u001b[0m\u001b[38;5;200m\u001b[1;3mAlbum, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track\u001b[0m\u001b[32;1m\u001b[1;3m\n", - "Invoking: `sql_db_schema` with `PlaylistTrack`\n", - "\n", - "\n", - "\u001b[0m\u001b[33;1m\u001b[1;3m\n", - "CREATE TABLE \"PlaylistTrack\" (\n", - "\t\"PlaylistId\" INTEGER NOT NULL, \n", - "\t\"TrackId\" INTEGER NOT NULL, \n", - "\tPRIMARY KEY (\"PlaylistId\", \"TrackId\"), \n", - "\tFOREIGN KEY(\"TrackId\") REFERENCES \"Track\" (\"TrackId\"), \n", - "\tFOREIGN KEY(\"PlaylistId\") REFERENCES \"Playlist\" (\"PlaylistId\")\n", - ")\n", - "\n", - "/*\n", - "3 rows from PlaylistTrack table:\n", - "PlaylistId\tTrackId\n", - "1\t3402\n", - "1\t3389\n", - "1\t3390\n", - "*/\u001b[0m\u001b[32;1m\u001b[1;3mThe `PlaylistTrack` table has two columns: `PlaylistId` and `TrackId`. It is a junction table that represents the many-to-many relationship between playlists and tracks. \n", - "\n", - "Here is the schema of the `PlaylistTrack` table:\n", - "\n", - "```\n", - "CREATE TABLE \"PlaylistTrack\" (\n", - "\t\"PlaylistId\" INTEGER NOT NULL, \n", - "\t\"TrackId\" INTEGER NOT NULL, \n", - "\tPRIMARY KEY (\"PlaylistId\", \"TrackId\"), \n", - "\tFOREIGN KEY(\"TrackId\") REFERENCES \"Track\" (\"TrackId\"), \n", - "\tFOREIGN KEY(\"PlaylistId\") REFERENCES \"Playlist\" (\"PlaylistId\")\n", - ")\n", - "```\n", - "\n", - "The `PlaylistId` column is a foreign key referencing the `PlaylistId` column in the `Playlist` table. The `TrackId` column is a foreign key referencing the `TrackId` column in the `Track` table.\n", - "\n", - "Here are three sample rows from the `PlaylistTrack` table:\n", - "\n", - "```\n", - "PlaylistId TrackId\n", - "1 3402\n", - "1 3389\n", - "1 3390\n", - "```\n", - "\n", - "Please let me know if there is anything else I can help with.\u001b[0m\n", - "\n", - "\u001b[1m> Finished chain.\u001b[0m\n" - ] - }, - { - "data": { - "text/plain": [ - "{'input': 'Describe the playlisttrack table',\n", - " 'output': 'The `PlaylistTrack` table has two columns: `PlaylistId` and `TrackId`. It is a junction table that represents the many-to-many relationship between playlists and tracks. 
\\n\\nHere is the schema of the `PlaylistTrack` table:\\n\\n```\\nCREATE TABLE \"PlaylistTrack\" (\\n\\t\"PlaylistId\" INTEGER NOT NULL, \\n\\t\"TrackId\" INTEGER NOT NULL, \\n\\tPRIMARY KEY (\"PlaylistId\", \"TrackId\"), \\n\\tFOREIGN KEY(\"TrackId\") REFERENCES \"Track\" (\"TrackId\"), \\n\\tFOREIGN KEY(\"PlaylistId\") REFERENCES \"Playlist\" (\"PlaylistId\")\\n)\\n```\\n\\nThe `PlaylistId` column is a foreign key referencing the `PlaylistId` column in the `Playlist` table. The `TrackId` column is a foreign key referencing the `TrackId` column in the `Track` table.\\n\\nHere are three sample rows from the `PlaylistTrack` table:\\n\\n```\\nPlaylistId TrackId\\n1 3402\\n1 3389\\n1 3390\\n```\\n\\nPlease let me know if there is anything else I can help with.'}" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "agent_executor.invoke({\"input\": \"Describe the playlisttrack table\"})" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Next steps\n", - "\n", - "For more on how to use and customize agents head to the [Agents](/docs/use_cases/sql/agents) page." - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "poetry-venv", - "language": "python", - "name": "poetry-venv" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.1" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/libs/community/langchain_community/storage/__init__.py b/libs/community/langchain_community/storage/__init__.py index 9a73d49110afc..d75b497bf584f 100644 --- a/libs/community/langchain_community/storage/__init__.py +++ b/libs/community/langchain_community/storage/__init__.py @@ -31,6 +31,9 @@ from langchain_community.storage.redis import ( RedisStore, ) + from langchain_community.storage.sql import ( + SQLStore, + ) from langchain_community.storage.upstash_redis import ( UpstashRedisByteStore, UpstashRedisStore, @@ -42,6 +45,7 @@ "CassandraByteStore", "MongoDBStore", "RedisStore", + "SQLStore", "UpstashRedisByteStore", "UpstashRedisStore", ] @@ -52,6 +56,7 @@ "CassandraByteStore": "langchain_community.storage.cassandra", "MongoDBStore": "langchain_community.storage.mongodb", "RedisStore": "langchain_community.storage.redis", + "SQLStore": "langchain_community.storage.sql", "UpstashRedisByteStore": "langchain_community.storage.upstash_redis", "UpstashRedisStore": "langchain_community.storage.upstash_redis", } diff --git a/libs/community/langchain_community/storage/sql_docstore.py b/libs/community/langchain_community/storage/sql_docstore.py deleted file mode 100644 index 5a1adfffa0fd3..0000000000000 --- a/libs/community/langchain_community/storage/sql_docstore.py +++ /dev/null @@ -1,236 +0,0 @@ -import contextlib -from pathlib import Path -from typing import ( - Any, - Dict, - Generator, - Iterator, - List, - Optional, - Sequence, - Tuple, - Union, -) - -from langchain_core.stores import BaseStore -from sqlalchemy import ( - Column, - Engine, - PickleType, - and_, - create_engine, -) -from sqlalchemy.ext.asyncio import ( - AsyncEngine, - AsyncSession, - async_sessionmaker, - create_async_engine, -) -from sqlalchemy.orm import ( - Mapped, - Session, - declarative_base, - mapped_column, - scoped_session, - sessionmaker, -) - -Base = declarative_base() - - -def items_equal(x: Any, y: Any) -> bool: - return x == y - - 
-class Value(Base): # type: ignore[valid-type,misc] - """Table used to save values.""" - - # ATTENTION: - # Prior to modifying this table, please determine whether - # we should create migrations for this table to make sure - # users do not experience data loss. - __tablename__ = "docstore" - - namespace: Mapped[str] = mapped_column(primary_key=True, index=True, nullable=False) - key: Mapped[str] = mapped_column(primary_key=True, index=True, nullable=False) - # value: Mapped[Any] = Column(type_=PickleType, index=False, nullable=False) - value: Any = Column("earthquake", PickleType(comparator=items_equal)) - - -# This is a fix of original SQLStore. -# This can will be removed when a PR will be merged. -class SQLStore(BaseStore[str, bytes]): - """BaseStore interface that works on an SQL database. - - Examples: - Create a SQLStore instance and perform operations on it: - - .. code-block:: python - - from langchain_rag.storage import SQLStore - - # Instantiate the SQLStore with the root path - sql_store = SQLStore(namespace="test", db_url="sqllite://:memory:") - - # Set values for keys - sql_store.mset([("key1", b"value1"), ("key2", b"value2")]) - - # Get values for keys - values = sql_store.mget(["key1", "key2"]) # Returns [b"value1", b"value2"] - - # Delete keys - sql_store.mdelete(["key1"]) - - # Iterate over keys - for key in sql_store.yield_keys(): - print(key) - - """ - - def __init__( - self, - *, - namespace: str, - db_url: Optional[Union[str, Path]] = None, - engine: Optional[Union[Engine, AsyncEngine]] = None, - engine_kwargs: Optional[Dict[str, Any]] = None, - async_mode: bool = False, - ): - if db_url is None and engine is None: - raise ValueError("Must specify either db_url or engine") - - if db_url is not None and engine is not None: - raise ValueError("Must specify either db_url or engine, not both") - - _engine: Union[Engine, AsyncEngine] - if db_url: - if async_mode: - _engine = create_async_engine(url=str(db_url), **(engine_kwargs or {})) - else: - _engine = create_engine(url=str(db_url), **(engine_kwargs or {})) - elif engine: - _engine = engine - - else: - raise AssertionError("Something went wrong with configuration of engine.") - - _session_factory: Union[sessionmaker[Session], async_sessionmaker[AsyncSession]] - if isinstance(_engine, AsyncEngine): - _session_factory = async_sessionmaker(bind=_engine) - else: - _session_factory = sessionmaker(bind=_engine) - - self.engine = _engine - self.dialect = _engine.dialect.name - self.session_factory = _session_factory - self.namespace = namespace - - def create_schema(self) -> None: - Base.metadata.create_all(self.engine) - - def drop(self) -> None: - Base.metadata.drop_all(bind=self.engine.connect()) - - # async def amget(self, keys: Sequence[K]) -> List[Optional[V]]: - # result = {} - # async with self._make_session() as session: - # async with session.begin(): - # for v in session.query(Value).filter( - # and_( - # Value.key.in_(keys), - # Value.namespace == self.namespace, - # ) - # ): - # result[v.key] = v.value - # return [result.get(key) for key in keys] - - def mget(self, keys: Sequence[str]) -> List[Optional[bytes]]: - result = {} - - with self._make_session() as session: - for v in session.query(Value).filter( # type: ignore - and_( - Value.key.in_(keys), - Value.namespace == self.namespace, - ) - ): - result[v.key] = v.value - return [result.get(key) for key in keys] - - # async def amset(self, key_value_pairs: Sequence[Tuple[K, V]]) -> None: - # async with self._make_session() as session: - # async with 
session.begin(): - # # await self._amdetete([key for key, _ in key_value_pairs], session) - # session.add_all([Value(namespace=self.namespace, - # key=k, - # value=v) for k, v in key_value_pairs]) - # session.commit() - - def mset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None: - # try: - with self._make_session() as session: - self._mdetete([key for key, _ in key_value_pairs], session) - session.add_all( - [ - Value(namespace=self.namespace, key=k, value=v) - for k, v in key_value_pairs - ] - ) - - def _mdetete(self, keys: Sequence[str], session: Session) -> None: - with session.begin(): - session.query(Value).filter( # type: ignore - and_( - Value.key.in_(keys), - Value.namespace == self.namespace, - ) - ).delete() - - # async def _amdetete(self, keys: Sequence[str], session: Session) -> None: - # await session.query(Value).filter( - # and_( - # Value.key.in_(keys), - # Value.namespace == self.namespace, - # ) - # ).delete() - - def mdelete(self, keys: Sequence[str]) -> None: - with self._make_session() as session: - self._mdetete(keys, session) - session.commit() - - # async def amdelete(self, keys: Sequence[str]) -> None: - # with self._make_session() as session: - # await self._mdelete(keys, session) - # session.commit() - - def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]: - with self._make_session() as session: - for v in session.query(Value).filter( # type: ignore - Value.namespace == self.namespace - ): - yield str(v.key) - session.close() - - @contextlib.contextmanager - def _make_session(self) -> Generator[Session, None, None]: - """Create a session and close it after use.""" - - if isinstance(self.session_factory, async_sessionmaker): - raise AssertionError("This method is not supported for async engines.") - - session = scoped_session(self.session_factory)() - try: - yield session - finally: - session.commit() - - # @contextlib.asynccontextmanager - # async def _amake_session(self) -> AsyncGenerator[AsyncSession, None]: - # """Create a session and close it after use.""" - # - # if not isinstance(self.session_factory, async_sessionmaker): - # raise AssertionError("This method is not supported for sync engines.") - # - # async with self.session_factory() as session: - # yield session diff --git a/libs/community/tests/unit_tests/storage/test_imports.py b/libs/community/tests/unit_tests/storage/test_imports.py index 750b7c5a3e2f7..791f0298cc5e7 100644 --- a/libs/community/tests/unit_tests/storage/test_imports.py +++ b/libs/community/tests/unit_tests/storage/test_imports.py @@ -5,6 +5,7 @@ "AstraDBByteStore", "CassandraByteStore", "MongoDBStore", + "SQLStore", "RedisStore", "UpstashRedisByteStore", "UpstashRedisStore", diff --git a/libs/community/tests/unit_tests/storage/test_sql.py b/libs/community/tests/unit_tests/storage/test_sql.py new file mode 100644 index 0000000000000..3502450a6b6c3 --- /dev/null +++ b/libs/community/tests/unit_tests/storage/test_sql.py @@ -0,0 +1,89 @@ +from typing import AsyncGenerator, Generator, cast + +import pytest +from langchain.storage._lc_store import create_kv_docstore, create_lc_store +from langchain_core.documents import Document +from langchain_core.stores import BaseStore + +from langchain_community.storage.sql import SQLStore + + +@pytest.fixture +def sql_store() -> Generator[SQLStore, None, None]: + store = SQLStore(namespace="test", db_url="sqlite://") + store.create_schema() + yield store + + +@pytest.fixture +async def async_sql_store() -> AsyncGenerator[SQLStore, None]: + store = 
SQLStore(namespace="test", db_url="sqlite+aiosqlite://", async_mode=True) + await store.acreate_schema() + yield store + + +def test_create_lc_store(sql_store: SQLStore) -> None: + """Test that a docstore is created from a base store.""" + docstore: BaseStore[str, Document] = cast( + BaseStore[str, Document], create_lc_store(sql_store) + ) + docstore.mset([("key1", Document(page_content="hello", metadata={"key": "value"}))]) + fetched_doc = docstore.mget(["key1"])[0] + assert fetched_doc is not None + assert fetched_doc.page_content == "hello" + assert fetched_doc.metadata == {"key": "value"} + + +def test_create_kv_store(sql_store: SQLStore) -> None: + """Test that a docstore is created from a base store.""" + docstore = create_kv_docstore(sql_store) + docstore.mset([("key1", Document(page_content="hello", metadata={"key": "value"}))]) + fetched_doc = docstore.mget(["key1"])[0] + assert isinstance(fetched_doc, Document) + assert fetched_doc.page_content == "hello" + assert fetched_doc.metadata == {"key": "value"} + + +@pytest.mark.asyncio +async def test_async_create_kv_store(async_sql_store: SQLStore) -> None: + """Test that a docstore is created from a base store.""" + docstore = create_kv_docstore(async_sql_store) + await docstore.amset( + [("key1", Document(page_content="hello", metadata={"key": "value"}))] + ) + fetched_doc = (await docstore.amget(["key1"]))[0] + assert isinstance(fetched_doc, Document) + assert fetched_doc.page_content == "hello" + assert fetched_doc.metadata == {"key": "value"} + + +def test_sample_sql_docstore(sql_store: SQLStore) -> None: + # Set values for keys + sql_store.mset([("key1", b"value1"), ("key2", b"value2")]) + + # Get values for keys + values = sql_store.mget(["key1", "key2"]) # Returns [b"value1", b"value2"] + assert values == [b"value1", b"value2"] + # Delete keys + sql_store.mdelete(["key1"]) + + # Iterate over keys + assert [key for key in sql_store.yield_keys()] == ["key2"] + + +@pytest.mark.asyncio +async def test_async_sample_sql_docstore(async_sql_store: SQLStore) -> None: + # Set values for keys + await async_sql_store.amset([("key1", b"value1"), ("key2", b"value2")]) + # sql_store.mset([("key1", "value1"), ("key2", "value2")]) + + # Get values for keys + values = await async_sql_store.amget( + ["key1", "key2"] + ) # Returns [b"value1", b"value2"] + assert values == [b"value1", b"value2"] + # Delete keys + await async_sql_store.amdelete(["key1"]) + + # Iterate over keys + assert [key async for key in async_sql_store.ayield_keys()] == ["key2"] diff --git a/libs/community/tests/unit_tests/test_dependencies.py b/libs/community/tests/unit_tests/test_dependencies.py index ce4389d8313ab..c2ad95e85e0a2 100644 --- a/libs/community/tests/unit_tests/test_dependencies.py +++ b/libs/community/tests/unit_tests/test_dependencies.py @@ -74,6 +74,7 @@ def test_test_group_dependencies(poetry_conf: Mapping[str, Any]) -> None: assert test_group_deps == sorted( [ + "aiosqlite", "duckdb-engine", "freezegun", "langchain-core", diff --git a/libs/core/tests/unit_tests/_api/test_deprecation.py b/libs/core/tests/unit_tests/_api/test_deprecation.py index dbfd2554aabb5..fc05c002575da 100644 --- a/libs/core/tests/unit_tests/_api/test_deprecation.py +++ b/libs/core/tests/unit_tests/_api/test_deprecation.py @@ -155,30 +155,6 @@ async def test_deprecated_async_function() -> None: assert inspect.iscoroutinefunction(deprecated_async_function) - assert not inspect.iscoroutinefunction(deprecated_function) - - -@pytest.mark.asyncio -async def test_deprecated_async_function() 
-> None: - """Test deprecated async function.""" - with warnings.catch_warnings(record=True) as warning_list: - warnings.simplefilter("always") - assert ( - await deprecated_async_function() == "This is a deprecated async function." - ) - assert len(warning_list) == 1 - warning = warning_list[0].message - assert str(warning) == ( - "The function `deprecated_async_function` was deprecated " - "in LangChain 2.0.0 and will be removed in 3.0.0" - ) - - doc = deprecated_function.__doc__ - assert isinstance(doc, str) - assert doc.startswith("[*Deprecated*] original doc") - - assert inspect.iscoroutinefunction(deprecated_async_function) - def test_deprecated_method() -> None: """Test deprecated method.""" @@ -222,31 +198,6 @@ async def test_deprecated_async_method() -> None: assert inspect.iscoroutinefunction(obj.deprecated_async_method) - assert not inspect.iscoroutinefunction(obj.deprecated_method) - - -@pytest.mark.asyncio -async def test_deprecated_async_method() -> None: - """Test deprecated async method.""" - with warnings.catch_warnings(record=True) as warning_list: - warnings.simplefilter("always") - obj = ClassWithDeprecatedMethods() - assert ( - await obj.deprecated_async_method() == "This is a deprecated async method." - ) - assert len(warning_list) == 1 - warning = warning_list[0].message - assert str(warning) == ( - "The function `deprecated_async_method` was deprecated in " - "LangChain 2.0.0 and will be removed in 3.0.0" - ) - - doc = obj.deprecated_method.__doc__ - assert isinstance(doc, str) - assert doc.startswith("[*Deprecated*] original doc") - - assert inspect.iscoroutinefunction(obj.deprecated_async_method) - def test_deprecated_classmethod() -> None: """Test deprecated classmethod.""" From 5f06ce3516826c29922cfd3e9a812ce24a5d177d Mon Sep 17 00:00:00 2001 From: Philippe Prados Date: Fri, 7 Jun 2024 12:29:52 +0200 Subject: [PATCH 06/13] Add unit test --- .../langchain_community/storage/sql.py | 265 ++++++++++++++++++ libs/community/poetry.lock | 93 +++--- libs/community/pyproject.toml | 1 + .../integration_tests/storage/test_sql.py | 165 +++++++++++ 4 files changed, 473 insertions(+), 51 deletions(-) create mode 100644 libs/community/langchain_community/storage/sql.py create mode 100644 libs/community/tests/integration_tests/storage/test_sql.py diff --git a/libs/community/langchain_community/storage/sql.py b/libs/community/langchain_community/storage/sql.py new file mode 100644 index 0000000000000..86a9b8d50286a --- /dev/null +++ b/libs/community/langchain_community/storage/sql.py @@ -0,0 +1,265 @@ +import contextlib +from pathlib import Path +from typing import ( + Any, + AsyncGenerator, + AsyncIterator, + Dict, + Generator, + Iterator, + List, + Optional, + Sequence, + Tuple, + Union, + cast, +) + +from langchain_core.stores import BaseStore, V +from sqlalchemy import ( + Column, + Engine, + PickleType, + and_, + create_engine, + delete, + select, +) +from sqlalchemy.ext.asyncio import ( + AsyncEngine, + AsyncSession, + async_sessionmaker, + create_async_engine, +) +from sqlalchemy.orm import ( + Mapped, + Session, + declarative_base, + mapped_column, + sessionmaker, +) + +Base = declarative_base() + + +def items_equal(x: Any, y: Any) -> bool: + return x == y + + +class Value(Base): # type: ignore[valid-type,misc] + """Table used to save values.""" + + # ATTENTION: + # Prior to modifying this table, please determine whether + # we should create migrations for this table to make sure + # users do not experience data loss. 
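+    # The composite primary key (namespace, key) scopes each stored value to
+    # its namespace, so several stores can safely share this single table.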
+ __tablename__ = "docstore" + + namespace: Mapped[str] = mapped_column(primary_key=True, index=True, nullable=False) + key: Mapped[str] = mapped_column(primary_key=True, index=True, nullable=False) + value: Mapped[Any] = Column(type_=PickleType, index=False, nullable=False) + + +# This is a fix of original SQLStore. +# This can will be removed when a PR will be merged. +class SQLStore(BaseStore[str, bytes]): + """BaseStore interface that works on an SQL database. + + Examples: + Create a SQLStore instance and perform operations on it: + + .. code-block:: python + + from langchain_rag.storage import SQLStore + + # Instantiate the SQLStore with the root path + sql_store = SQLStore(namespace="test", db_url="sqllite://:memory:") + + # Set values for keys + sql_store.mset([("key1", b"value1"), ("key2", b"value2")]) + + # Get values for keys + values = sql_store.mget(["key1", "key2"]) # Returns [b"value1", b"value2"] + + # Delete keys + sql_store.mdelete(["key1"]) + + # Iterate over keys + for key in sql_store.yield_keys(): + print(key) + + """ + + def __init__( + self, + *, + namespace: str, + db_url: Optional[Union[str, Path]] = None, + engine: Optional[Union[Engine, AsyncEngine]] = None, + engine_kwargs: Optional[Dict[str, Any]] = None, + async_mode: Optional[bool] = None, + ): + if db_url is None and engine is None: + raise ValueError("Must specify either db_url or engine") + + if db_url is not None and engine is not None: + raise ValueError("Must specify either db_url or engine, not both") + + _engine: Union[Engine, AsyncEngine] + if db_url: + if async_mode is None: + async_mode = False + if async_mode: + _engine = create_async_engine( + url=str(db_url), + **(engine_kwargs or {}), + ) + else: + _engine = create_engine(url=str(db_url), **(engine_kwargs or {})) + elif engine: + _engine = engine + + else: + raise AssertionError("Something went wrong with configuration of engine.") + + _session_maker: Union[sessionmaker[Session], async_sessionmaker[AsyncSession]] + if isinstance(_engine, AsyncEngine): + self.async_mode = True + _session_maker = async_sessionmaker(bind=_engine) + else: + self.async_mode = False + _session_maker = sessionmaker(bind=_engine) + + self.engine = _engine + self.dialect = _engine.dialect.name + self.session_maker = _session_maker + self.namespace = namespace + + def create_schema(self) -> None: + Base.metadata.create_all(self.engine) + + async def acreate_schema(self) -> None: + assert isinstance(self.engine, AsyncEngine) + async with self.engine.begin() as session: + await session.run_sync(Base.metadata.create_all) + + def drop(self) -> None: + Base.metadata.drop_all(bind=self.engine.connect()) + + async def amget(self, keys: Sequence[str]) -> List[Optional[V]]: + assert isinstance(self.engine, AsyncEngine) + result: Dict[str, V] = {} + async with self._make_async_session() as session: + stmt = select(Value).filter( + and_( + Value.key.in_(keys), + Value.namespace == self.namespace, + ) + ) + for v in await session.scalars(stmt): + result[v.key] = v.value + return [result.get(key) for key in keys] + + def mget(self, keys: Sequence[str]) -> List[Optional[bytes]]: + result = {} + + with self._make_sync_session() as session: + stmt = select(Value).filter( + and_( + Value.key.in_(keys), + Value.namespace == self.namespace, + ) + ) + for v in session.scalars(stmt): + result[v.key] = v.value + return [result.get(key) for key in keys] + + async def amset(self, key_value_pairs: Sequence[Tuple[str, V]]) -> None: + async with self._make_async_session() as session: + await 
self._amdelete([key for key, _ in key_value_pairs], session) + session.add_all( + [ + Value(namespace=self.namespace, key=k, value=v) + for k, v in key_value_pairs + ] + ) + await session.commit() + + def mset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None: + values: Dict[str, bytes] = dict(key_value_pairs) + with self._make_sync_session() as session: + self._mdelete(list(values.keys()), session) + session.add_all( + [ + Value(namespace=self.namespace, key=k, value=v) + for k, v in values.items() + ] + ) + session.commit() + + def _mdelete(self, keys: Sequence[str], session: Session) -> None: + stmt = delete(Value).filter( + and_( + Value.key.in_(keys), + Value.namespace == self.namespace, + ) + ) + session.execute(stmt) + + async def _amdelete(self, keys: Sequence[str], session: AsyncSession) -> None: + stmt = delete(Value).filter( + and_( + Value.key.in_(keys), + Value.namespace == self.namespace, + ) + ) + await session.execute(stmt) + + def mdelete(self, keys: Sequence[str]) -> None: + with self._make_sync_session() as session: + self._mdelete(keys, session) + session.commit() + + async def amdelete(self, keys: Sequence[str]) -> None: + async with self._make_async_session() as session: + await self._amdelete(keys, session) + await session.commit() + + def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]: + with self._make_sync_session() as session: + for v in session.query(Value).filter( # type: ignore + Value.namespace == self.namespace + ): + if str(v.key).startswith(prefix or ""): + yield str(v.key) + session.close() + + async def ayield_keys(self, *, prefix: Optional[str] = None) -> AsyncIterator[str]: + async with self._make_async_session() as session: + stmt = select(Value).filter(Value.namespace == self.namespace) + for v in await session.scalars(stmt): + if str(v.key).startswith(prefix or ""): + yield str(v.key) + await session.close() + + @contextlib.contextmanager + def _make_sync_session(self) -> Generator[Session, None, None]: + """Make a sync session.""" + if self.async_mode: + raise ValueError( + "Attempting to use a sync method when async mode is turned on. " + "Please use the corresponding async method instead." + ) + with cast(Session, self.session_maker()) as session: + yield cast(Session, session) + + @contextlib.asynccontextmanager + async def _make_async_session(self) -> AsyncGenerator[AsyncSession, None]: + """Make an async session.""" + if not self.async_mode: + raise ValueError( + "Attempting to use an async method when sync mode is turned on. " + "Please use the corresponding sync method instead." + ) + async with cast(AsyncSession, self.session_maker()) as session: + yield cast(AsyncSession, session) diff --git a/libs/community/poetry.lock b/libs/community/poetry.lock index fc08f4ac48a65..f91268cb6d659 100644 --- a/libs/community/poetry.lock +++ b/libs/community/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. 
[[package]] name = "aiohttp" @@ -110,6 +110,21 @@ files = [ [package.dependencies] frozenlist = ">=1.1.0" +[[package]] +name = "aiosqlite" +version = "0.19.0" +description = "asyncio bridge to the standard sqlite3 module" +optional = false +python-versions = ">=3.7" +files = [ + {file = "aiosqlite-0.19.0-py3-none-any.whl", hash = "sha256:edba222e03453e094a3ce605db1b970c4b3376264e56f32e2a4959f948d66a96"}, + {file = "aiosqlite-0.19.0.tar.gz", hash = "sha256:95ee77b91c8d2808bd08a59fbebf66270e9090c3d92ffbf260dc0db0b979577d"}, +] + +[package.extras] +dev = ["aiounittest (==1.4.1)", "attribution (==1.6.2)", "black (==23.3.0)", "coverage[toml] (==7.2.3)", "flake8 (==5.0.4)", "flake8-bugbear (==23.3.12)", "flit (==3.7.1)", "mypy (==1.2.0)", "ufmt (==2.1.0)", "usort (==1.0.6)"] +docs = ["sphinx (==6.1.3)", "sphinx-mdinclude (==0.5.3)"] + [[package]] name = "annotated-types" version = "0.7.0" @@ -2116,7 +2131,7 @@ files = [ [[package]] name = "langchain" -version = "0.2.2" +version = "0.2.3" description = "Building applications with LLMs through composability" optional = false python-versions = ">=3.8.1,<4.0" @@ -2125,6 +2140,7 @@ develop = true [package.dependencies] aiohttp = "^3.8.3" +async-timeout = {version = "^4.0.0", markers = "python_version < \"3.11\""} langchain-core = "^0.2.0" langchain-text-splitters = "^0.2.0" langsmith = "^0.1.17" @@ -2141,7 +2157,7 @@ url = "../langchain" [[package]] name = "langchain-core" -version = "0.2.4" +version = "0.2.5" description = "Building applications with LLMs through composability" optional = false python-versions = ">=3.8.1,<4.0" @@ -2150,7 +2166,7 @@ develop = true [package.dependencies] jsonpatch = "^1.33" -langsmith = "^0.1.66" +langsmith = "^0.1.75" packaging = "^23.2" pydantic = ">=1,<3" PyYAML = ">=5.3" @@ -2178,13 +2194,13 @@ url = "../text-splitters" [[package]] name = "langsmith" -version = "0.1.73" +version = "0.1.75" description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform." 
optional = false python-versions = "<4.0,>=3.8.1" files = [ - {file = "langsmith-0.1.73-py3-none-any.whl", hash = "sha256:38bfcce2cfcf0b2da2e9628b903c9e768e1ce59d450e8a584514c1638c595e93"}, - {file = "langsmith-0.1.73.tar.gz", hash = "sha256:0055471cb1fddb76ec65499716764ad0b0314affbdf33ff1f72ad5e2d6a3b224"}, + {file = "langsmith-0.1.75-py3-none-any.whl", hash = "sha256:d08b08dd6b3fa4da170377f95123d77122ef4c52999d10fff4ae08ff70d07aed"}, + {file = "langsmith-0.1.75.tar.gz", hash = "sha256:61274e144ea94c297dd78ce03e6dfae18459fe9bd8ab5094d61a0c4816561279"}, ] [package.dependencies] @@ -3022,8 +3038,8 @@ files = [ [package.dependencies] numpy = [ {version = ">=1.20.3", markers = "python_version < \"3.10\""}, - {version = ">=1.21.0", markers = "python_version >= \"3.10\" and python_version < \"3.11\""}, {version = ">=1.23.2", markers = "python_version >= \"3.11\""}, + {version = ">=1.21.0", markers = "python_version >= \"3.10\" and python_version < \"3.11\""}, ] python-dateutil = ">=2.8.2" pytz = ">=2020.1" @@ -5084,20 +5100,6 @@ files = [ cryptography = ">=35.0.0" types-pyOpenSSL = "*" -[[package]] -name = "types-requests" -version = "2.31.0.6" -description = "Typing stubs for requests" -optional = false -python-versions = ">=3.7" -files = [ - {file = "types-requests-2.31.0.6.tar.gz", hash = "sha256:cd74ce3b53c461f1228a9b783929ac73a666658f223e28ed29753771477b3bd0"}, - {file = "types_requests-2.31.0.6-py3-none-any.whl", hash = "sha256:a2db9cb228a81da8348b49ad6db3f5519452dd20a9c1e1a868c83c5fe88fd1a9"}, -] - -[package.dependencies] -types-urllib3 = "*" - [[package]] name = "types-requests" version = "2.32.0.20240602" @@ -5134,17 +5136,6 @@ files = [ {file = "types_toml-0.10.8.20240310-py3-none-any.whl", hash = "sha256:627b47775d25fa29977d9c70dc0cbab3f314f32c8d8d0c012f2ef5de7aaec05d"}, ] -[[package]] -name = "types-urllib3" -version = "1.26.25.14" -description = "Typing stubs for urllib3" -optional = false -python-versions = "*" -files = [ - {file = "types-urllib3-1.26.25.14.tar.gz", hash = "sha256:229b7f577c951b8c1b92c1bc2b2fdb0b49847bd2af6d1cc2a2e3dd340f3bda8f"}, - {file = "types_urllib3-1.26.25.14-py3-none-any.whl", hash = "sha256:9683bbb7fb72e32bfe9d2be6e04875fbe1b3eeec3cbb4ea231435aa7fd6b4f0e"}, -] - [[package]] name = "typing-extensions" version = "4.12.1" @@ -5196,22 +5187,6 @@ files = [ [package.extras] dev = ["flake8", "flake8-annotations", "flake8-bandit", "flake8-bugbear", "flake8-commas", "flake8-comprehensions", "flake8-continuation", "flake8-datetimez", "flake8-docstrings", "flake8-import-order", "flake8-literal", "flake8-modern-annotations", "flake8-noqa", "flake8-pyproject", "flake8-requirements", "flake8-typechecking-import", "flake8-use-fstring", "mypy", "pep8-naming", "types-PyYAML"] -[[package]] -name = "urllib3" -version = "1.26.18" -description = "HTTP library with thread-safe connection pooling, file post, and more." 
-optional = false -python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*" -files = [ - {file = "urllib3-1.26.18-py2.py3-none-any.whl", hash = "sha256:34b97092d7e0a3a8cf7cd10e386f401b3737364026c45e622aa02903dffe0f07"}, - {file = "urllib3-1.26.18.tar.gz", hash = "sha256:f8ecc1bba5667413457c529ab955bf8c67b45db799d159066261719e328580a0"}, -] - -[package.extras] -brotli = ["brotli (==1.0.9)", "brotli (>=1.0.9)", "brotlicffi (>=0.8.0)", "brotlipy (>=0.6.0)"] -secure = ["certifi", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "ipaddress", "pyOpenSSL (>=0.14)", "urllib3-secure-extra"] -socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"] - [[package]] name = "urllib3" version = "2.2.1" @@ -5229,6 +5204,23 @@ h2 = ["h2 (>=4,<5)"] socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] zstd = ["zstandard (>=0.18.0)"] +[[package]] +name = "vcrpy" +version = "4.3.0" +description = "Automatically mock your HTTP interactions to simplify and speed up testing" +optional = false +python-versions = ">=3.7" +files = [ + {file = "vcrpy-4.3.0-py2.py3-none-any.whl", hash = "sha256:8fbd4be412e8a7f35f623dd61034e6380a1c8dbd0edf6e87277a3289f6e98093"}, + {file = "vcrpy-4.3.0.tar.gz", hash = "sha256:49c270ce67e826dba027d83e20d25b67a5885487697e97bca6dbdf53d750a0ac"}, +] + +[package.dependencies] +PyYAML = "*" +six = ">=1.5" +wrapt = "*" +yarl = "*" + [[package]] name = "vcrpy" version = "6.0.1" @@ -5241,7 +5233,6 @@ files = [ [package.dependencies] PyYAML = "*" -urllib3 = {version = "<2", markers = "platform_python_implementation == \"PyPy\" or python_version < \"3.10\""} wrapt = "*" yarl = "*" @@ -5651,4 +5642,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "21ead159a299fcc5cc0a9038ddcee5b4f355c893f23ef80456b72941ad3122fd" +content-hash = "7b0a65085916f44b5de275c204483691a5433027647189fe8a91367eec56a491" diff --git a/libs/community/pyproject.toml b/libs/community/pyproject.toml index 3d4ec6797b183..e27ef2652257e 100644 --- a/libs/community/pyproject.toml +++ b/libs/community/pyproject.toml @@ -31,6 +31,7 @@ optional = true pytest = "^7.3.0" pytest-cov = "^4.1.0" pytest-dotenv = "^0.5.2" +aiosqlite = "^0.19.0" duckdb-engine = "^0.11.0" pytest-watcher = "^0.2.6" freezegun = "^1.2.2" diff --git a/libs/community/tests/integration_tests/storage/test_sql.py b/libs/community/tests/integration_tests/storage/test_sql.py new file mode 100644 index 0000000000000..e655a08a8a17f --- /dev/null +++ b/libs/community/tests/integration_tests/storage/test_sql.py @@ -0,0 +1,165 @@ +"""Implement integration tests for SQL storage.""" + +import pickle + +import pytest +from sqlalchemy import Engine, create_engine, text +from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine + +from langchain_community.storage import SQLStore + +# if typing.TYPE_CHECKING: +# from sqlalchemy import Engine, create_engine + +pytest.importorskip("sqlalchemy") + + +@pytest.fixture +def sql_engine() -> Engine: + """Yield a sync SQLAlchemy engine.""" + return create_engine(url="sqlite://", echo=True) + + +@pytest.fixture +def sql_aengine() -> Engine: + """Yield an async SQLAlchemy engine.""" + return create_async_engine(url="sqlite+aiosqlite:///:memory:", echo=True) + + +def test_mget(sql_engine: Engine) -> None: + """Test mget method.""" + store = SQLStore(engine=sql_engine, namespace="test") + store.create_schema() + keys = ["key1", "key2"] + with sql_engine.connect() as session: + session.execute(text("insert into docstore ('namespace', 'key', 'value') " + 
"values('test','key1',:value)") + .bindparams(value=pickle.dumps(b"value1")), + ) + session.execute(text("insert into docstore ('namespace', 'key', 'value') " + "values('test','key2',:value)") + .bindparams(value=pickle.dumps(b"value2")), + ) + session.commit() + + result = store.mget(keys) + assert result == [b"value1", b"value2"] + + +@pytest.mark.asyncio +async def test_amget(sql_aengine: AsyncEngine) -> None: + """Test mget method.""" + store = SQLStore(engine=sql_aengine, namespace="test") + await store.acreate_schema() + keys = ["key1", "key2"] + async with sql_aengine.connect() as session: + await session.execute(text("insert into docstore ('namespace', 'key', 'value') " + "values('test','key1',:value)") + .bindparams(value=pickle.dumps(b"value1")), + ) + await session.execute(text("insert into docstore ('namespace', 'key', 'value') " + "values('test','key2',:value)") + .bindparams(value=pickle.dumps(b"value2")), + ) + await session.commit() + + result = await store.amget(keys) + assert result == [b"value1", b"value2"] + + +def test_mset(sql_engine: Engine) -> None: + """Test that multiple keys can be set.""" + store = SQLStore(engine=sql_engine, namespace="test") + store.create_schema() + key_value_pairs = [("key1", b"value1"), ("key2", b"value2")] + store.mset(key_value_pairs) + + with sql_engine.connect() as session: + result = session.exec_driver_sql("select * from docstore") + assert result.keys() == ["namespace", "key", "value"] + data = [(row[0], row[1]) for row in result] + assert data == [("test", "key1"), ("test", "key2")] + session.commit() + +@pytest.mark.asyncio +async def test_amset(sql_aengine: Engine) -> None: + """Test that multiple keys can be set.""" + store = SQLStore(engine=sql_aengine, namespace="test") + await store.acreate_schema() + key_value_pairs = [("key1", b"value1"), ("key2", b"value2")] + await store.amset(key_value_pairs) + + async with sql_aengine.connect() as session: + result = await session.exec_driver_sql("select * from docstore") + assert result.keys() == ["namespace", "key", "value"] + data = [(row[0], row[1]) for row in result] + assert data == [("test", "key1"), ("test", "key2")] + session.commit() + +def test_mdelete(sql_engine: Engine) -> None: + """Test that deletion works as expected.""" + store = SQLStore(engine=sql_engine, namespace="test") + store.create_schema() + keys = ["key1", "key2"] + with sql_engine.connect() as session: + session.execute(text("insert into docstore ('namespace', 'key', 'value') " + "values('test','key1',:value)") + .bindparams(value=pickle.dumps(b"value1")), + ) + session.execute(text("insert into docstore ('namespace', 'key', 'value') " + "values('test','key2',:value)") + .bindparams(value=pickle.dumps(b"value2")), + ) + session.commit() + store.mdelete(keys) + with sql_engine.connect() as session: + result = session.exec_driver_sql("select * from docstore") + assert result.keys() == ["namespace", "key", "value"] + data = [row for row in result] + assert data == [] + session.commit() + +@pytest.mark.asyncio +async def test_amdelete(sql_aengine: Engine) -> None: + """Test that deletion works as expected.""" + store = SQLStore(engine=sql_aengine, namespace="test") + await store.acreate_schema() + keys = ["key1", "key2"] + async with sql_aengine.connect() as session: + await session.execute(text("insert into docstore ('namespace', 'key', 'value') " + "values('test','key1',:value)") + .bindparams(value=pickle.dumps(b"value1")), + ) + await session.execute(text("insert into docstore ('namespace', 'key', 'value') " + 
"values('test','key2',:value)") + .bindparams(value=pickle.dumps(b"value2")), + ) + await session.commit() + await store.amdelete(keys) + async with sql_aengine.connect() as session: + result = await session.exec_driver_sql("select * from docstore") + assert result.keys() == ["namespace", "key", "value"] + data = [row for row in result] + assert data == [] + await session.commit() + + +def test_yield_keys(sql_engine: Engine) -> None: + store = SQLStore(engine=sql_engine, namespace="test") + store.create_schema() + key_value_pairs = [("key1", b"value1"), ("key2", b"value2")] + store.mset(key_value_pairs) + assert sorted(store.yield_keys()) == ["key1", "key2"] + assert sorted(store.yield_keys(prefix="key")) == ["key1", "key2"] + assert sorted(store.yield_keys(prefix="lang")) == [] + + +@pytest.mark.asyncio +async def test_yield_keys(sql_aengine: Engine) -> None: + store = SQLStore(engine=sql_aengine, namespace="test") + await store.acreate_schema() + key_value_pairs = [("key1", b"value1"), ("key2", b"value2")] + await store.amset(key_value_pairs) + assert sorted([k async for k in store.ayield_keys()]) == ["key1", "key2"] + assert sorted([k async for k in store.ayield_keys(prefix="key")]) == ["key1", "key2"] + assert sorted([k async for k in store.ayield_keys(prefix="lang")]) == [] From e8f1f57640b42fdcf5ea78ddc97891446aef9637 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Wed, 5 Jun 2024 11:18:09 -0400 Subject: [PATCH 07/13] Revert changes in tests_dependencies --- libs/community/tests/unit_tests/test_dependencies.py | 1 - 1 file changed, 1 deletion(-) diff --git a/libs/community/tests/unit_tests/test_dependencies.py b/libs/community/tests/unit_tests/test_dependencies.py index c2ad95e85e0a2..ce4389d8313ab 100644 --- a/libs/community/tests/unit_tests/test_dependencies.py +++ b/libs/community/tests/unit_tests/test_dependencies.py @@ -74,7 +74,6 @@ def test_test_group_dependencies(poetry_conf: Mapping[str, Any]) -> None: assert test_group_deps == sorted( [ - "aiosqlite", "duckdb-engine", "freezegun", "langchain-core", From 261271410032679b3a45c6d04f94859939e250c0 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Wed, 5 Jun 2024 11:20:32 -0400 Subject: [PATCH 08/13] Update markers so aiosqlite is handled as an optional --- libs/community/tests/unit_tests/storage/test_sql.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/libs/community/tests/unit_tests/storage/test_sql.py b/libs/community/tests/unit_tests/storage/test_sql.py index 3502450a6b6c3..084f0e2d19089 100644 --- a/libs/community/tests/unit_tests/storage/test_sql.py +++ b/libs/community/tests/unit_tests/storage/test_sql.py @@ -44,7 +44,7 @@ def test_create_kv_store(sql_store: SQLStore) -> None: assert fetched_doc.metadata == {"key": "value"} -@pytest.mark.asyncio +@pytest.mark.requires("aiosqlite") async def test_async_create_kv_store(async_sql_store: SQLStore) -> None: """Test that a docstore is created from a base store.""" docstore = create_kv_docstore(async_sql_store) @@ -71,7 +71,7 @@ def test_sample_sql_docstore(sql_store: SQLStore) -> None: assert [key for key in sql_store.yield_keys()] == ["key2"] -@pytest.mark.asyncio +@pytest.mark.requires("aiosqlite") async def test_async_sample_sql_docstore(async_sql_store: SQLStore) -> None: # Set values for keys await async_sql_store.amset([("key1", b"value1"), ("key2", b"value2")]) From d06ff90011f084e0eff24bc6f9bf60344cb16be9 Mon Sep 17 00:00:00 2001 From: Philippe Prados Date: Fri, 7 Jun 2024 13:12:54 +0200 Subject: [PATCH 09/13] Add unit test --- 
 .../langchain_community/storage/sql.py        |  15 ++-
 .../integration_tests/storage/test_sql.py    | 115 +++++++++++-------
 2 files changed, 75 insertions(+), 55 deletions(-)

diff --git a/libs/community/langchain_community/storage/sql.py b/libs/community/langchain_community/storage/sql.py
index 86a9b8d50286a..0767fce9b1899 100644
--- a/libs/community/langchain_community/storage/sql.py
+++ b/libs/community/langchain_community/storage/sql.py
@@ -15,11 +15,10 @@
     cast,
 )
 
-from langchain_core.stores import BaseStore, V
+from langchain_core.stores import BaseStore
 from sqlalchemy import (
-    Column,
     Engine,
-    PickleType,
+    LargeBinary,
     and_,
     create_engine,
     delete,
@@ -53,11 +52,11 @@ class Value(Base):  # type: ignore[valid-type,misc]
     # Prior to modifying this table, please determine whether
     # we should create migrations for this table to make sure
     # users do not experience data loss.
-    __tablename__ = "docstore"
+    __tablename__ = "langchain_key_value_stores"
 
     namespace: Mapped[str] = mapped_column(primary_key=True, index=True, nullable=False)
     key: Mapped[str] = mapped_column(primary_key=True, index=True, nullable=False)
-    value: Mapped[Any] = Column(type_=PickleType, index=False, nullable=False)
+    value = mapped_column(LargeBinary, index=False, nullable=False)
 
 
 # This is a fix of original SQLStore.
@@ -146,9 +145,9 @@ async def acreate_schema(self) -> None:
     def drop(self) -> None:
         Base.metadata.drop_all(bind=self.engine.connect())
 
-    async def amget(self, keys: Sequence[str]) -> List[Optional[V]]:
+    async def amget(self, keys: Sequence[str]) -> List[Optional[bytes]]:
         assert isinstance(self.engine, AsyncEngine)
-        result: Dict[str, V] = {}
+        result: Dict[str, bytes] = {}
         async with self._make_async_session() as session:
             stmt = select(Value).filter(
                 and_(
@@ -174,7 +173,7 @@ def mget(self, keys: Sequence[str]) -> List[Optional[bytes]]:
                 result[v.key] = v.value
         return [result.get(key) for key in keys]
 
-    async def amset(self, key_value_pairs: Sequence[Tuple[str, V]]) -> None:
+    async def amset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None:
         async with self._make_async_session() as session:
             await self._amdelete([key for key, _ in key_value_pairs], session)
             session.add_all(
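
Note that the column change is not backward compatible on disk: rows written through
the old PickleType column hold a pickle of the payload, while the new LargeBinary
column holds the payload itself, so any pre-existing table would need a one-off
migration. A small, self-contained illustration of the difference:

    import pickle

    # Under PickleType the database held a pickle of the payload, not the
    # payload itself, so old rows do not round-trip through the bytes-only
    # LargeBinary column without being unpickled first.
    stored_before = pickle.dumps(b"value1")  # what PickleType wrote to disk
    stored_after = b"value1"                 # what LargeBinary writes to disk

    assert stored_before != stored_after
    assert pickle.loads(stored_before) == stored_after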

diff --git a/libs/community/tests/integration_tests/storage/test_sql.py b/libs/community/tests/integration_tests/storage/test_sql.py
index e655a08a8a17f..a454029b86cdf 100644
--- a/libs/community/tests/integration_tests/storage/test_sql.py
+++ b/libs/community/tests/integration_tests/storage/test_sql.py
@@ -1,16 +1,11 @@
 """Implement integration tests for Redis storage."""
 
-import pickle
-
 import pytest
 from sqlalchemy import Engine, create_engine, text
 from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
 
 from langchain_community.storage import SQLStore
 
-# if typing.TYPE_CHECKING:
-#    from sqlalchemy import Engine, create_engine
-
 pytest.importorskip("sqlalchemy")
 
 
@@ -21,7 +16,7 @@ def sql_engine() -> Engine:
 
 
 @pytest.fixture
-def sql_aengine() -> Engine:
+def sql_aengine() -> AsyncEngine:
     """Yield redis client."""
     return create_async_engine(url="sqlite+aiosqlite:///:memory:", echo=True)
 
@@ -32,14 +27,18 @@ def test_mget(sql_engine: Engine) -> None:
     store.create_schema()
     keys = ["key1", "key2"]
     with sql_engine.connect() as session:
-        session.execute(text("insert into docstore ('namespace', 'key', 'value') "
-                             "values('test','key1',:value)")
-                        .bindparams(value=pickle.dumps(b"value1")),
-                        )
-        session.execute(text("insert into docstore ('namespace', 'key', 'value') "
-                             "values('test','key2',:value)")
-                        .bindparams(value=pickle.dumps(b"value2")),
-                        )
+        session.execute(
+            text(
+                "insert into langchain_key_value_stores ('namespace', 'key', 'value') "
+                "values('test','key1',:value)"
+            ).bindparams(value=b"value1"),
+        )
+        session.execute(
+            text(
+                "insert into langchain_key_value_stores ('namespace', 'key', 'value') "
+                "values('test','key2',:value)"
+            ).bindparams(value=b"value2"),
+        )
         session.commit()
 
     result = store.mget(keys)
@@ -53,14 +52,18 @@ async def test_amget(sql_aengine: AsyncEngine) -> None:
     await store.acreate_schema()
     keys = ["key1", "key2"]
     async with sql_aengine.connect() as session:
-        await session.execute(text("insert into docstore ('namespace', 'key', 'value') "
-                                   "values('test','key1',:value)")
-                              .bindparams(value=pickle.dumps(b"value1")),
-                              )
-        await session.execute(text("insert into docstore ('namespace', 'key', 'value') "
-                                   "values('test','key2',:value)")
-                              .bindparams(value=pickle.dumps(b"value2")),
-                              )
+        await session.execute(
+            text(
+                "insert into langchain_key_value_stores ('namespace', 'key', 'value') "
+                "values('test','key1',:value)"
+            ).bindparams(value=b"value1"),
+        )
+        await session.execute(
+            text(
+                "insert into langchain_key_value_stores ('namespace', 'key', 'value') "
+                "values('test','key2',:value)"
+            ).bindparams(value=b"value2"),
+        )
         await session.commit()
 
     result = await store.amget(keys)
@@ -75,14 +78,15 @@ def test_mset(sql_engine: Engine) -> None:
     store.mset(key_value_pairs)
 
     with sql_engine.connect() as session:
-        result = session.exec_driver_sql("select * from docstore")
+        result = session.exec_driver_sql("select * from langchain_key_value_stores")
         assert result.keys() == ["namespace", "key", "value"]
         data = [(row[0], row[1]) for row in result]
         assert data == [("test", "key1"), ("test", "key2")]
         session.commit()
 
+
 @pytest.mark.asyncio
-async def test_amset(sql_aengine: Engine) -> None:
+async def test_amset(sql_aengine: AsyncEngine) -> None:
     """Test that multiple keys can be set."""
     store = SQLStore(engine=sql_aengine, namespace="test")
     await store.acreate_schema()
@@ -90,11 +94,14 @@ async def test_amset(sql_aengine: Engine) -> None:
     await store.amset(key_value_pairs)
 
     async with sql_aengine.connect() as session:
-        result = await session.exec_driver_sql("select * from docstore")
+        result = await session.exec_driver_sql(
+            "select * from langchain_key_value_stores"
+        )
         assert result.keys() == ["namespace", "key", "value"]
         data = [(row[0], row[1]) for row in result]
         assert data == [("test", "key1"), ("test", "key2")]
-        session.commit()
+        await session.commit()
+
 
 def test_mdelete(sql_engine: Engine) -> None:
     """Test that deletion works as expected."""
@@ -102,42 +109,53 @@ def test_mdelete(sql_engine: Engine) -> None:
     store.create_schema()
     keys = ["key1", "key2"]
     with sql_engine.connect() as session:
-        session.execute(text("insert into docstore ('namespace', 'key', 'value') "
-                             "values('test','key1',:value)")
-                        .bindparams(value=pickle.dumps(b"value1")),
-                        )
-        session.execute(text("insert into docstore ('namespace', 'key', 'value') "
-                             "values('test','key2',:value)")
-                        .bindparams(value=pickle.dumps(b"value2")),
-                        )
+        session.execute(
+            text(
+                "insert into langchain_key_value_stores ('namespace', 'key', 'value') "
+                "values('test','key1',:value)"
+            ).bindparams(value=b"value1"),
+        )
+        session.execute(
+            text(
+                "insert into langchain_key_value_stores ('namespace', 'key', 'value') "
+                "values('test','key2',:value)"
+            ).bindparams(value=b"value2"),
+        )
         session.commit()
     store.mdelete(keys)
     with sql_engine.connect() as session:
-        result = session.exec_driver_sql("select * from docstore")
+        result = session.exec_driver_sql("select * from langchain_key_value_stores")
         assert result.keys() == ["namespace", "key", "value"]
         data = [row for row in result]
        assert data == []
         session.commit()
 
+
 @pytest.mark.asyncio
-async def test_amdelete(sql_aengine: Engine) -> None:
+async def test_amdelete(sql_aengine: AsyncEngine) -> None:
     """Test that deletion works as expected."""
     store = SQLStore(engine=sql_aengine, namespace="test")
     await store.acreate_schema()
     keys = ["key1", "key2"]
     async with sql_aengine.connect() as session:
-        await session.execute(text("insert into docstore ('namespace', 'key', 'value') "
-                                   "values('test','key1',:value)")
-                              .bindparams(value=pickle.dumps(b"value1")),
-                              )
-        await session.execute(text("insert into docstore ('namespace', 'key', 'value') "
-                                   "values('test','key2',:value)")
-                              .bindparams(value=pickle.dumps(b"value2")),
-                              )
+        await session.execute(
+            text(
+                "insert into langchain_key_value_stores ('namespace', 'key', 'value') "
+                "values('test','key1',:value)"
+            ).bindparams(value=b"value1"),
+        )
+        await session.execute(
+            text(
+                "insert into langchain_key_value_stores ('namespace', 'key', 'value') "
+                "values('test','key2',:value)"
+            ).bindparams(value=b"value2"),
+        )
         await session.commit()
     await store.amdelete(keys)
     async with sql_aengine.connect() as session:
-        result = await session.exec_driver_sql("select * from docstore")
+        result = await session.exec_driver_sql(
+            "select * from langchain_key_value_stores"
+        )
         assert result.keys() == ["namespace", "key", "value"]
         data = [row for row in result]
         assert data == []
         await session.commit()
@@ -155,11 +173,14 @@ def test_yield_keys(sql_engine: Engine) -> None:
     assert sorted(store.yield_keys(prefix="lang")) == []
 
 
 @pytest.mark.asyncio
-async def test_yield_keys(sql_aengine: Engine) -> None:
+async def test_ayield_keys(sql_aengine: AsyncEngine) -> None:
     store = SQLStore(engine=sql_aengine, namespace="test")
     await store.acreate_schema()
     key_value_pairs = [("key1", b"value1"), ("key2", b"value2")]
     await store.amset(key_value_pairs)
     assert sorted([k async for k in store.ayield_keys()]) == ["key1", "key2"]
-    assert sorted([k async for k in store.ayield_keys(prefix="key")]) == ["key1", "key2"]
+    assert sorted([k async for k in store.ayield_keys(prefix="key")]) == [
+        "key1",
+        "key2",
+    ]
     assert sorted([k async for k in store.ayield_keys(prefix="lang")]) == []
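
The await session.commit() fix in test_amset above is easy to miss: on an async
connection, commit() only builds a coroutine, and a bare session.commit() never
reaches the database (Python at best emits a "coroutine was never awaited"
RuntimeWarning). A minimal sketch of the corrected pattern, assuming the aiosqlite
driver:

    import asyncio

    from sqlalchemy import text
    from sqlalchemy.ext.asyncio import create_async_engine


    async def main() -> None:
        engine = create_async_engine("sqlite+aiosqlite:///:memory:")
        async with engine.connect() as conn:
            await conn.execute(text("select 1"))
            # conn.commit() alone would only create a coroutine and drop it;
            # the commit reaches the database only when it is awaited.
            await conn.commit()


    asyncio.run(main())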

From 2e12455eaec476c54cfcb403d07a07878618ffc5 Mon Sep 17 00:00:00 2001
From: Philippe Prados
Date: Fri, 7 Jun 2024 13:15:37 +0200
Subject: [PATCH 10/13] Rename Value()

---
 .../langchain_community/storage/sql.py        | 38 ++++++++++---------
 1 file changed, 20 insertions(+), 18 deletions(-)

diff --git a/libs/community/langchain_community/storage/sql.py b/libs/community/langchain_community/storage/sql.py
index 0767fce9b1899..a92daae1d8c67 100644
--- a/libs/community/langchain_community/storage/sql.py
+++ b/libs/community/langchain_community/storage/sql.py
@@ -45,7 +45,7 @@ def items_equal(x: Any, y: Any) -> bool:
     return x == y
 
 
-class Value(Base):  # type: ignore[valid-type,misc]
+class LangchainKeyValueStores(Base):  # type: ignore[valid-type,misc]
     """Table used to save values."""
 
     # ATTENTION:
@@ -149,10 +149,10 @@ async def amget(self, keys: Sequence[str]) -> List[Optional[bytes]]:
         assert isinstance(self.engine, AsyncEngine)
         result: Dict[str, bytes] = {}
         async with self._make_async_session() as session:
-            stmt = select(Value).filter(
+            stmt = select(LangchainKeyValueStores).filter(
                 and_(
-                    Value.key.in_(keys),
-                    Value.namespace == self.namespace,
+                    LangchainKeyValueStores.key.in_(keys),
+                    LangchainKeyValueStores.namespace == self.namespace,
                 )
             )
             for v in await session.scalars(stmt):
@@ -163,10 +163,10 @@ def mget(self, keys: Sequence[str]) -> List[Optional[bytes]]:
         result = {}
 
         with self._make_sync_session() as session:
-            stmt = select(Value).filter(
+            stmt = select(LangchainKeyValueStores).filter(
                 and_(
-                    Value.key.in_(keys),
-                    Value.namespace == self.namespace,
+                    LangchainKeyValueStores.key.in_(keys),
+                    LangchainKeyValueStores.namespace == self.namespace,
                 )
             )
             for v in session.scalars(stmt):
@@ -178,7 +178,7 @@ async def amset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None:
             await self._amdelete([key for key, _ in key_value_pairs], session)
             session.add_all(
                 [
-                    Value(namespace=self.namespace, key=k, value=v)
+                    LangchainKeyValueStores(namespace=self.namespace, key=k, value=v)
                     for k, v in key_value_pairs
                 ]
             )
@@ -190,26 +190,26 @@ def mset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None:
             self._mdelete(list(values.keys()), session)
             session.add_all(
                 [
-                    Value(namespace=self.namespace, key=k, value=v)
+                    LangchainKeyValueStores(namespace=self.namespace, key=k, value=v)
                     for k, v in values.items()
                 ]
             )
             session.commit()
 
     def _mdelete(self, keys: Sequence[str], session: Session) -> None:
-        stmt = delete(Value).filter(
+        stmt = delete(LangchainKeyValueStores).filter(
             and_(
-                Value.key.in_(keys),
-                Value.namespace == self.namespace,
+                LangchainKeyValueStores.key.in_(keys),
+                LangchainKeyValueStores.namespace == self.namespace,
             )
         )
         session.execute(stmt)
 
     async def _amdelete(self, keys: Sequence[str], session: AsyncSession) -> None:
-        stmt = delete(Value).filter(
+        stmt = delete(LangchainKeyValueStores).filter(
             and_(
-                Value.key.in_(keys),
-                Value.namespace == self.namespace,
+                LangchainKeyValueStores.key.in_(keys),
+                LangchainKeyValueStores.namespace == self.namespace,
             )
         )
         await session.execute(stmt)
@@ -226,8 +226,8 @@ async def amdelete(self, keys: Sequence[str]) -> None:
 
     def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
         with self._make_sync_session() as session:
-            for v in session.query(Value).filter(  # type: ignore
-                Value.namespace == self.namespace
+            for v in session.query(LangchainKeyValueStores).filter(  # type: ignore
+                LangchainKeyValueStores.namespace == self.namespace
             ):
                 if str(v.key).startswith(prefix or ""):
                     yield str(v.key)
@@ -235,7 +235,9 @@ def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
 
     async def ayield_keys(self, *, prefix: Optional[str] = None) -> AsyncIterator[str]:
         async with self._make_async_session() as session:
-            stmt = select(Value).filter(Value.namespace == self.namespace)
+            stmt = select(LangchainKeyValueStores).filter(
+                LangchainKeyValueStores.namespace == self.namespace
+            )
             for v in await session.scalars(stmt):
                 if str(v.key).startswith(prefix or ""):
                     yield str(v.key)
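
Renaming the mapped class from Value to LangchainKeyValueStores is a Python-level
change only: DDL and queries are driven by __tablename__, which the previous commit
already set to langchain_key_value_stores, so no schema migration is implied. A
compact sketch of why, under the same SQLAlchemy 2.x assumptions as above:

    from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


    class Base(DeclarativeBase):
        pass


    class LangchainKeyValueStores(Base):
        __tablename__ = "langchain_key_value_stores"

        key: Mapped[str] = mapped_column(primary_key=True)


    # The mapped class name is only a Python identifier; DDL and SQL are
    # generated from __tablename__, so renaming the class touches no schema.
    assert LangchainKeyValueStores.__table__.name == "langchain_key_value_stores"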

From 21e7558e4f750d37d629d7c4171b667ef77c41d1 Mon Sep 17 00:00:00 2001
From: Philippe Prados
Date: Fri, 7 Jun 2024 13:21:47 +0200
Subject: [PATCH 11/13] Fix test_dependencies

---
 libs/community/tests/unit_tests/test_dependencies.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/libs/community/tests/unit_tests/test_dependencies.py b/libs/community/tests/unit_tests/test_dependencies.py
index ce4389d8313ab..c2ad95e85e0a2 100644
--- a/libs/community/tests/unit_tests/test_dependencies.py
+++ b/libs/community/tests/unit_tests/test_dependencies.py
@@ -74,6 +74,7 @@ def test_test_group_dependencies(poetry_conf: Mapping[str, Any]) -> None:
 
     assert test_group_deps == sorted(
         [
+            "aiosqlite",
            "duckdb-engine",
            "freezegun",
            "langchain-core",
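
With aiosqlite restored to the test group here (and dropped again from the test
dependencies by the final commits below), the async tests stay runnable locally while
the driver remains optional for users. The unit tests rely on the repository's
@pytest.mark.requires("aiosqlite") marker for this; an equivalent stand-alone guard,
a sketch using the same pytest.importorskip idiom the integration tests already apply
to sqlalchemy:

    import pytest

    # Skip the whole module when the optional driver is missing, mirroring the
    # existing pytest.importorskip("sqlalchemy") guard in the integration tests.
    aiosqlite = pytest.importorskip("aiosqlite")


    def test_driver_available() -> None:
        assert aiosqlite.__version__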
"langchain-core", From 6cea0058ddab5529300fb3ac3a045af4ab008917 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Fri, 7 Jun 2024 17:00:40 -0400 Subject: [PATCH 12/13] x --- libs/community/poetry.lock | 93 +++++++++++++++++++---------------- libs/community/pyproject.toml | 1 - 2 files changed, 51 insertions(+), 43 deletions(-) diff --git a/libs/community/poetry.lock b/libs/community/poetry.lock index f91268cb6d659..fc08f4ac48a65 100644 --- a/libs/community/poetry.lock +++ b/libs/community/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. [[package]] name = "aiohttp" @@ -110,21 +110,6 @@ files = [ [package.dependencies] frozenlist = ">=1.1.0" -[[package]] -name = "aiosqlite" -version = "0.19.0" -description = "asyncio bridge to the standard sqlite3 module" -optional = false -python-versions = ">=3.7" -files = [ - {file = "aiosqlite-0.19.0-py3-none-any.whl", hash = "sha256:edba222e03453e094a3ce605db1b970c4b3376264e56f32e2a4959f948d66a96"}, - {file = "aiosqlite-0.19.0.tar.gz", hash = "sha256:95ee77b91c8d2808bd08a59fbebf66270e9090c3d92ffbf260dc0db0b979577d"}, -] - -[package.extras] -dev = ["aiounittest (==1.4.1)", "attribution (==1.6.2)", "black (==23.3.0)", "coverage[toml] (==7.2.3)", "flake8 (==5.0.4)", "flake8-bugbear (==23.3.12)", "flit (==3.7.1)", "mypy (==1.2.0)", "ufmt (==2.1.0)", "usort (==1.0.6)"] -docs = ["sphinx (==6.1.3)", "sphinx-mdinclude (==0.5.3)"] - [[package]] name = "annotated-types" version = "0.7.0" @@ -2131,7 +2116,7 @@ files = [ [[package]] name = "langchain" -version = "0.2.3" +version = "0.2.2" description = "Building applications with LLMs through composability" optional = false python-versions = ">=3.8.1,<4.0" @@ -2140,7 +2125,6 @@ develop = true [package.dependencies] aiohttp = "^3.8.3" -async-timeout = {version = "^4.0.0", markers = "python_version < \"3.11\""} langchain-core = "^0.2.0" langchain-text-splitters = "^0.2.0" langsmith = "^0.1.17" @@ -2157,7 +2141,7 @@ url = "../langchain" [[package]] name = "langchain-core" -version = "0.2.5" +version = "0.2.4" description = "Building applications with LLMs through composability" optional = false python-versions = ">=3.8.1,<4.0" @@ -2166,7 +2150,7 @@ develop = true [package.dependencies] jsonpatch = "^1.33" -langsmith = "^0.1.75" +langsmith = "^0.1.66" packaging = "^23.2" pydantic = ">=1,<3" PyYAML = ">=5.3" @@ -2194,13 +2178,13 @@ url = "../text-splitters" [[package]] name = "langsmith" -version = "0.1.75" +version = "0.1.73" description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform." 
 optional = false
 python-versions = "<4.0,>=3.8.1"
 files = [
-    {file = "langsmith-0.1.75-py3-none-any.whl", hash = "sha256:d08b08dd6b3fa4da170377f95123d77122ef4c52999d10fff4ae08ff70d07aed"},
-    {file = "langsmith-0.1.75.tar.gz", hash = "sha256:61274e144ea94c297dd78ce03e6dfae18459fe9bd8ab5094d61a0c4816561279"},
+    {file = "langsmith-0.1.73-py3-none-any.whl", hash = "sha256:38bfcce2cfcf0b2da2e9628b903c9e768e1ce59d450e8a584514c1638c595e93"},
+    {file = "langsmith-0.1.73.tar.gz", hash = "sha256:0055471cb1fddb76ec65499716764ad0b0314affbdf33ff1f72ad5e2d6a3b224"},
 ]
 
 [package.dependencies]
@@ -3038,8 +3022,8 @@ files = [
 [package.dependencies]
 numpy = [
     {version = ">=1.20.3", markers = "python_version < \"3.10\""},
-    {version = ">=1.23.2", markers = "python_version >= \"3.11\""},
     {version = ">=1.21.0", markers = "python_version >= \"3.10\" and python_version < \"3.11\""},
+    {version = ">=1.23.2", markers = "python_version >= \"3.11\""},
 ]
 python-dateutil = ">=2.8.2"
 pytz = ">=2020.1"
@@ -5100,6 +5084,20 @@ files = [
 cryptography = ">=35.0.0"
 types-pyOpenSSL = "*"
 
+[[package]]
+name = "types-requests"
+version = "2.31.0.6"
+description = "Typing stubs for requests"
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "types-requests-2.31.0.6.tar.gz", hash = "sha256:cd74ce3b53c461f1228a9b783929ac73a666658f223e28ed29753771477b3bd0"},
+    {file = "types_requests-2.31.0.6-py3-none-any.whl", hash = "sha256:a2db9cb228a81da8348b49ad6db3f5519452dd20a9c1e1a868c83c5fe88fd1a9"},
+]
+
+[package.dependencies]
+types-urllib3 = "*"
+
 [[package]]
 name = "types-requests"
 version = "2.32.0.20240602"
@@ -5136,6 +5134,17 @@ files = [
     {file = "types_toml-0.10.8.20240310-py3-none-any.whl", hash = "sha256:627b47775d25fa29977d9c70dc0cbab3f314f32c8d8d0c012f2ef5de7aaec05d"},
 ]
 
+[[package]]
+name = "types-urllib3"
+version = "1.26.25.14"
+description = "Typing stubs for urllib3"
+optional = false
+python-versions = "*"
+files = [
+    {file = "types-urllib3-1.26.25.14.tar.gz", hash = "sha256:229b7f577c951b8c1b92c1bc2b2fdb0b49847bd2af6d1cc2a2e3dd340f3bda8f"},
+    {file = "types_urllib3-1.26.25.14-py3-none-any.whl", hash = "sha256:9683bbb7fb72e32bfe9d2be6e04875fbe1b3eeec3cbb4ea231435aa7fd6b4f0e"},
+]
+
 [[package]]
 name = "typing-extensions"
 version = "4.12.1"
@@ -5187,6 +5196,22 @@ files = [
 [package.extras]
 dev = ["flake8", "flake8-annotations", "flake8-bandit", "flake8-bugbear", "flake8-commas", "flake8-comprehensions", "flake8-continuation", "flake8-datetimez", "flake8-docstrings", "flake8-import-order", "flake8-literal", "flake8-modern-annotations", "flake8-noqa", "flake8-pyproject", "flake8-requirements", "flake8-typechecking-import", "flake8-use-fstring", "mypy", "pep8-naming", "types-PyYAML"]
 
+[[package]]
+name = "urllib3"
+version = "1.26.18"
+description = "HTTP library with thread-safe connection pooling, file post, and more."
+optional = false
+python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*"
+files = [
+    {file = "urllib3-1.26.18-py2.py3-none-any.whl", hash = "sha256:34b97092d7e0a3a8cf7cd10e386f401b3737364026c45e622aa02903dffe0f07"},
+    {file = "urllib3-1.26.18.tar.gz", hash = "sha256:f8ecc1bba5667413457c529ab955bf8c67b45db799d159066261719e328580a0"},
+]
+
+[package.extras]
+brotli = ["brotli (==1.0.9)", "brotli (>=1.0.9)", "brotlicffi (>=0.8.0)", "brotlipy (>=0.6.0)"]
+secure = ["certifi", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "ipaddress", "pyOpenSSL (>=0.14)", "urllib3-secure-extra"]
+socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"]
+
 [[package]]
 name = "urllib3"
 version = "2.2.1"
@@ -5204,23 +5229,6 @@ h2 = ["h2 (>=4,<5)"]
 socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"]
 zstd = ["zstandard (>=0.18.0)"]
 
-[[package]]
-name = "vcrpy"
-version = "4.3.0"
-description = "Automatically mock your HTTP interactions to simplify and speed up testing"
-optional = false
-python-versions = ">=3.7"
-files = [
-    {file = "vcrpy-4.3.0-py2.py3-none-any.whl", hash = "sha256:8fbd4be412e8a7f35f623dd61034e6380a1c8dbd0edf6e87277a3289f6e98093"},
-    {file = "vcrpy-4.3.0.tar.gz", hash = "sha256:49c270ce67e826dba027d83e20d25b67a5885487697e97bca6dbdf53d750a0ac"},
-]
-
-[package.dependencies]
-PyYAML = "*"
-six = ">=1.5"
-wrapt = "*"
-yarl = "*"
-
 [[package]]
 name = "vcrpy"
 version = "6.0.1"
@@ -5233,6 +5241,7 @@ files = [
 
 [package.dependencies]
 PyYAML = "*"
+urllib3 = {version = "<2", markers = "platform_python_implementation == \"PyPy\" or python_version < \"3.10\""}
 wrapt = "*"
 yarl = "*"
 
@@ -5642,4 +5651,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools",
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.8.1,<4.0"
-content-hash = "7b0a65085916f44b5de275c204483691a5433027647189fe8a91367eec56a491"
+content-hash = "21ead159a299fcc5cc0a9038ddcee5b4f355c893f23ef80456b72941ad3122fd"
diff --git a/libs/community/pyproject.toml b/libs/community/pyproject.toml
index e27ef2652257e..3d4ec6797b183 100644
--- a/libs/community/pyproject.toml
+++ b/libs/community/pyproject.toml
@@ -31,7 +31,6 @@ optional = true
 pytest = "^7.3.0"
 pytest-cov = "^4.1.0"
 pytest-dotenv = "^0.5.2"
-aiosqlite = "^0.19.0"
 duckdb-engine = "^0.11.0"
 pytest-watcher = "^0.2.6"
 freezegun = "^1.2.2"

From d95bf3db2a6dc956c169901db8e29656ab79d26c Mon Sep 17 00:00:00 2001
From: Eugene Yurtsev
Date: Fri, 7 Jun 2024 17:01:28 -0400
Subject: [PATCH 13/13] Remove from test dependencies

---
 libs/community/tests/unit_tests/test_dependencies.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/libs/community/tests/unit_tests/test_dependencies.py b/libs/community/tests/unit_tests/test_dependencies.py
index c2ad95e85e0a2..ce4389d8313ab 100644
--- a/libs/community/tests/unit_tests/test_dependencies.py
+++ b/libs/community/tests/unit_tests/test_dependencies.py
@@ -74,7 +74,6 @@ def test_test_group_dependencies(poetry_conf: Mapping[str, Any]) -> None:
 
     assert test_group_deps == sorted(
         [
-            "aiosqlite",
             "duckdb-engine",
             "freezegun",
             "langchain-core",
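
With the series applied, SQLStore can also back a document store via
create_kv_docstore, which the unit tests in this series exercise. A closing sketch,
assuming create_kv_docstore is importable from langchain.storage (the path used at
the time of this series; it may differ in later releases):

    from langchain.storage import create_kv_docstore
    from langchain_community.storage import SQLStore
    from langchain_core.documents import Document

    # Wrap the byte store so Documents are serialized and deserialized
    # transparently on the way in and out of the SQL table.
    store = SQLStore(namespace="test", db_url="sqlite:///:memory:")
    store.create_schema()

    docstore = create_kv_docstore(store)
    docstore.mset([("doc1", Document(page_content="hello", metadata={"key": "value"}))])

    fetched = docstore.mget(["doc1"])[0]
    assert fetched is not None and fetched.page_content == "hello"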