From 4f5266c9c2e6c375a0ffa75538defb1f117c6b27 Mon Sep 17 00:00:00 2001 From: Brian Pennington Date: Fri, 17 May 2024 11:29:18 -0500 Subject: [PATCH] updated checkpoint to include metadata. Fixed critical bug for postgres checkpoint. --- langchain_postgres/checkpoint.py | 34 +++++++++++++++++++++----------- 1 file changed, 22 insertions(+), 12 deletions(-) diff --git a/langchain_postgres/checkpoint.py b/langchain_postgres/checkpoint.py index 49522098..8a7d3af6 100644 --- a/langchain_postgres/checkpoint.py +++ b/langchain_postgres/checkpoint.py @@ -7,7 +7,7 @@ import psycopg from langchain_core.runnables import ConfigurableFieldSpec, RunnableConfig from langgraph.checkpoint import BaseCheckpointSaver -from langgraph.checkpoint.base import Checkpoint, CheckpointThreadTs, CheckpointTuple +from langgraph.checkpoint.base import Checkpoint, CheckpointThreadTs, CheckpointTuple, CheckpointMetadata from psycopg_pool import AsyncConnectionPool, ConnectionPool @@ -248,6 +248,7 @@ def create_tables(connection: Union[psycopg.Connection, ConnectionPool], /) -> N checkpoint BYTEA NOT NULL, thread_ts TIMESTAMPTZ NOT NULL, parent_ts TIMESTAMPTZ, + metadata BYTEA, PRIMARY KEY (thread_id, thread_ts) ); """ @@ -267,6 +268,7 @@ async def acreate_tables( checkpoint BYTEA NOT NULL, thread_ts TIMESTAMPTZ NOT NULL, parent_ts TIMESTAMPTZ, + metadata BYTEA, PRIMARY KEY (thread_id, thread_ts) ); """ @@ -284,7 +286,7 @@ async def adrop_tables(connection: psycopg.AsyncConnection, /) -> None: async with connection.cursor() as cur: await cur.execute("DROP TABLE IF EXISTS checkpoints;") - def put(self, config: RunnableConfig, checkpoint: Checkpoint) -> RunnableConfig: + def put(self, config: RunnableConfig, checkpoint: Checkpoint, metadata: CheckpointMetadata) -> RunnableConfig: """Put the checkpoint for the given configuration. Args: @@ -308,7 +310,7 @@ def put(self, config: RunnableConfig, checkpoint: Checkpoint) -> RunnableConfig: INSERT INTO checkpoints (thread_id, thread_ts, parent_ts, checkpoint) VALUES - (%(thread_id)s, %(thread_ts)s, %(parent_ts)s, %(checkpoint)s) + (%(thread_id)s, %(thread_ts)s, %(parent_ts)s, %(checkpoint)s, %(metadata)s) ON CONFLICT (thread_id, thread_ts) DO UPDATE SET checkpoint = EXCLUDED.checkpoint; """, @@ -317,6 +319,7 @@ def put(self, config: RunnableConfig, checkpoint: Checkpoint) -> RunnableConfig: "thread_ts": checkpoint["ts"], "parent_ts": parent_ts if parent_ts else None, "checkpoint": self.serializer.dumps(checkpoint), + "metadata": self.serializer.dumps(metadata), }, ) @@ -328,7 +331,7 @@ def put(self, config: RunnableConfig, checkpoint: Checkpoint) -> RunnableConfig: } async def aput( - self, config: RunnableConfig, checkpoint: Checkpoint + self, config: RunnableConfig, checkpoint: Checkpoint, metadata: CheckpointMetadata ) -> RunnableConfig: """Put the checkpoint for the given configuration. @@ -352,7 +355,7 @@ async def aput( INSERT INTO checkpoints (thread_id, thread_ts, parent_ts, checkpoint) VALUES - (%(thread_id)s, %(thread_ts)s, %(parent_ts)s, %(checkpoint)s) + (%(thread_id)s, %(thread_ts)s, %(parent_ts)s, %(checkpoint)s, %(metadata)s) ON CONFLICT (thread_id, thread_ts) DO UPDATE SET checkpoint = EXCLUDED.checkpoint; """, @@ -361,6 +364,7 @@ async def aput( "thread_ts": checkpoint["ts"], "parent_ts": parent_ts if parent_ts else None, "checkpoint": self.serializer.dumps(checkpoint), + "metadata": self.serializer.dumps(metadata), }, ) @@ -371,13 +375,13 @@ async def aput( }, } - def list(self, config: RunnableConfig) -> Generator[CheckpointTuple, None, None]: + def list(self, config: RunnableConfig) -> Generator[CheckpointTuple, None, None, None]: """Get all the checkpoints for the given configuration.""" with self._get_sync_connection() as conn: with conn.cursor() as cur: thread_id = config["configurable"]["thread_id"] cur.execute( - "SELECT checkpoint, thread_ts, parent_ts " + "SELECT checkpoint, thread_ts, parent_ts, metadata " "FROM checkpoints " "WHERE thread_id = %(thread_id)s " "ORDER BY thread_ts DESC", @@ -402,6 +406,7 @@ def list(self, config: RunnableConfig) -> Generator[CheckpointTuple, None, None] } if value[2] else None, + metadata = value[3] ) async def alist(self, config: RunnableConfig) -> AsyncIterator[CheckpointTuple]: @@ -410,7 +415,7 @@ async def alist(self, config: RunnableConfig) -> AsyncIterator[CheckpointTuple]: async with conn.cursor() as cur: thread_id = config["configurable"]["thread_id"] await cur.execute( - "SELECT checkpoint, thread_ts, parent_ts " + "SELECT checkpoint, thread_ts, parent_ts, metadata " "FROM checkpoints " "WHERE thread_id = %(thread_id)s " "ORDER BY thread_ts DESC", @@ -435,6 +440,7 @@ async def alist(self, config: RunnableConfig) -> AsyncIterator[CheckpointTuple]: } if value[2] else None, + metadata = value[3] ) def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: @@ -458,7 +464,7 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: with conn.cursor() as cur: if thread_ts: cur.execute( - "SELECT checkpoint, parent_ts " + "SELECT checkpoint, parent_ts, metadata " "FROM checkpoints " "WHERE thread_id = %(thread_id)s AND thread_ts = %(thread_ts)s", { @@ -479,10 +485,11 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: } if value[1] else None, + metadata = value[2] ) else: cur.execute( - "SELECT checkpoint, thread_ts, parent_ts " + "SELECT checkpoint, thread_ts, parent_ts, metadata " "FROM checkpoints " "WHERE thread_id = %(thread_id)s " "ORDER BY thread_ts DESC LIMIT 1", @@ -508,6 +515,7 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: } if value[2] else None, + metadata = value[3] ) return None @@ -532,7 +540,7 @@ async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: async with conn.cursor() as cur: if thread_ts: await cur.execute( - "SELECT checkpoint, parent_ts " + "SELECT checkpoint, parent_ts, metadata " "FROM checkpoints " "WHERE thread_id = %(thread_id)s AND thread_ts = %(thread_ts)s", { @@ -553,10 +561,11 @@ async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: } if value[1] else None, + metadata = value[2] ) else: await cur.execute( - "SELECT checkpoint, thread_ts, parent_ts " + "SELECT checkpoint, thread_ts, parent_ts, metadata " "FROM checkpoints " "WHERE thread_id = %(thread_id)s " "ORDER BY thread_ts DESC LIMIT 1", @@ -582,6 +591,7 @@ async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: } if value[2] else None, + metadata = value[3] ) return None