Skip to content

updated postgresql checkpoint to include metadata. #50

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 22 additions & 12 deletions langchain_postgres/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import psycopg
from langchain_core.runnables import ConfigurableFieldSpec, RunnableConfig
from langgraph.checkpoint import BaseCheckpointSaver
from langgraph.checkpoint.base import Checkpoint, CheckpointThreadTs, CheckpointTuple
from langgraph.checkpoint.base import Checkpoint, CheckpointThreadTs, CheckpointTuple, CheckpointMetadata
from psycopg_pool import AsyncConnectionPool, ConnectionPool


Expand Down Expand Up @@ -248,6 +248,7 @@ def create_tables(connection: Union[psycopg.Connection, ConnectionPool], /) -> N
checkpoint BYTEA NOT NULL,
thread_ts TIMESTAMPTZ NOT NULL,
parent_ts TIMESTAMPTZ,
metadata BYTEA,
PRIMARY KEY (thread_id, thread_ts)
);
"""
Expand All @@ -267,6 +268,7 @@ async def acreate_tables(
checkpoint BYTEA NOT NULL,
thread_ts TIMESTAMPTZ NOT NULL,
parent_ts TIMESTAMPTZ,
metadata BYTEA,
PRIMARY KEY (thread_id, thread_ts)
);
"""
Expand All @@ -284,7 +286,7 @@ async def adrop_tables(connection: psycopg.AsyncConnection, /) -> None:
async with connection.cursor() as cur:
await cur.execute("DROP TABLE IF EXISTS checkpoints;")

def put(self, config: RunnableConfig, checkpoint: Checkpoint) -> RunnableConfig:
def put(self, config: RunnableConfig, checkpoint: Checkpoint, metadata: CheckpointMetadata) -> RunnableConfig:
"""Put the checkpoint for the given configuration.

Args:
Expand All @@ -308,7 +310,7 @@ def put(self, config: RunnableConfig, checkpoint: Checkpoint) -> RunnableConfig:
INSERT INTO checkpoints
(thread_id, thread_ts, parent_ts, checkpoint)
VALUES
(%(thread_id)s, %(thread_ts)s, %(parent_ts)s, %(checkpoint)s)
(%(thread_id)s, %(thread_ts)s, %(parent_ts)s, %(checkpoint)s, %(metadata)s)
ON CONFLICT (thread_id, thread_ts)
DO UPDATE SET checkpoint = EXCLUDED.checkpoint;
""",
Expand All @@ -317,6 +319,7 @@ def put(self, config: RunnableConfig, checkpoint: Checkpoint) -> RunnableConfig:
"thread_ts": checkpoint["ts"],
"parent_ts": parent_ts if parent_ts else None,
"checkpoint": self.serializer.dumps(checkpoint),
"metadata": self.serializer.dumps(metadata),
},
)

Expand All @@ -328,7 +331,7 @@ def put(self, config: RunnableConfig, checkpoint: Checkpoint) -> RunnableConfig:
}

async def aput(
self, config: RunnableConfig, checkpoint: Checkpoint
self, config: RunnableConfig, checkpoint: Checkpoint, metadata: CheckpointMetadata
) -> RunnableConfig:
"""Put the checkpoint for the given configuration.

Expand All @@ -352,7 +355,7 @@ async def aput(
INSERT INTO
checkpoints (thread_id, thread_ts, parent_ts, checkpoint)
VALUES
(%(thread_id)s, %(thread_ts)s, %(parent_ts)s, %(checkpoint)s)
(%(thread_id)s, %(thread_ts)s, %(parent_ts)s, %(checkpoint)s, %(metadata)s)
ON CONFLICT (thread_id, thread_ts)
DO UPDATE SET checkpoint = EXCLUDED.checkpoint;
""",
Expand All @@ -361,6 +364,7 @@ async def aput(
"thread_ts": checkpoint["ts"],
"parent_ts": parent_ts if parent_ts else None,
"checkpoint": self.serializer.dumps(checkpoint),
"metadata": self.serializer.dumps(metadata),
},
)

Expand All @@ -371,13 +375,13 @@ async def aput(
},
}

def list(self, config: RunnableConfig) -> Generator[CheckpointTuple, None, None]:
def list(self, config: RunnableConfig) -> Generator[CheckpointTuple, None, None, None]:
"""Get all the checkpoints for the given configuration."""
with self._get_sync_connection() as conn:
with conn.cursor() as cur:
thread_id = config["configurable"]["thread_id"]
cur.execute(
"SELECT checkpoint, thread_ts, parent_ts "
"SELECT checkpoint, thread_ts, parent_ts, metadata "
"FROM checkpoints "
"WHERE thread_id = %(thread_id)s "
"ORDER BY thread_ts DESC",
Expand All @@ -402,6 +406,7 @@ def list(self, config: RunnableConfig) -> Generator[CheckpointTuple, None, None]
}
if value[2]
else None,
metadata = value[3]
)

async def alist(self, config: RunnableConfig) -> AsyncIterator[CheckpointTuple]:
Expand All @@ -410,7 +415,7 @@ async def alist(self, config: RunnableConfig) -> AsyncIterator[CheckpointTuple]:
async with conn.cursor() as cur:
thread_id = config["configurable"]["thread_id"]
await cur.execute(
"SELECT checkpoint, thread_ts, parent_ts "
"SELECT checkpoint, thread_ts, parent_ts, metadata "
"FROM checkpoints "
"WHERE thread_id = %(thread_id)s "
"ORDER BY thread_ts DESC",
Expand All @@ -435,6 +440,7 @@ async def alist(self, config: RunnableConfig) -> AsyncIterator[CheckpointTuple]:
}
if value[2]
else None,
metadata = value[3]
)

def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
Expand All @@ -458,7 +464,7 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
with conn.cursor() as cur:
if thread_ts:
cur.execute(
"SELECT checkpoint, parent_ts "
"SELECT checkpoint, parent_ts, metadata "
"FROM checkpoints "
"WHERE thread_id = %(thread_id)s AND thread_ts = %(thread_ts)s",
{
Expand All @@ -479,10 +485,11 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
}
if value[1]
else None,
metadata = value[2]
)
else:
cur.execute(
"SELECT checkpoint, thread_ts, parent_ts "
"SELECT checkpoint, thread_ts, parent_ts, metadata "
"FROM checkpoints "
"WHERE thread_id = %(thread_id)s "
"ORDER BY thread_ts DESC LIMIT 1",
Expand All @@ -508,6 +515,7 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
}
if value[2]
else None,
metadata = value[3]
)
return None

Expand All @@ -532,7 +540,7 @@ async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
async with conn.cursor() as cur:
if thread_ts:
await cur.execute(
"SELECT checkpoint, parent_ts "
"SELECT checkpoint, parent_ts, metadata "
"FROM checkpoints "
"WHERE thread_id = %(thread_id)s AND thread_ts = %(thread_ts)s",
{
Expand All @@ -553,10 +561,11 @@ async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
}
if value[1]
else None,
metadata = value[2]
)
else:
await cur.execute(
"SELECT checkpoint, thread_ts, parent_ts "
"SELECT checkpoint, thread_ts, parent_ts, metadata "
"FROM checkpoints "
"WHERE thread_id = %(thread_id)s "
"ORDER BY thread_ts DESC LIMIT 1",
Expand All @@ -582,6 +591,7 @@ async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
}
if value[2]
else None,
metadata = value[3]
)

return None
Loading