feat: Adding multi modal support for PGVectorStore #207
base: main
Changes from 6 commits
@@ -1,12 +1,16 @@
# TODO: Remove below import when minimum supported Python version is 3.10
from __future__ import annotations

import base64
import copy
import json
import re
import uuid
from typing import Any, Callable, Iterable, Optional, Sequence

import numpy as np
import requests
from google.cloud import storage  # type: ignore
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VectorStore, utils

@@ -365,6 +369,92 @@ async def aadd_documents(
    ids = await self.aadd_texts(texts, metadatas=metadatas, ids=ids, **kwargs)
    return ids

def _encode_image(self, uri: str) -> str:
    """Get base64 string from an image URI."""
    gcs_uri = re.match("gs://(.*?)/(.*)", uri)

    if gcs_uri:
        bucket_name, object_name = gcs_uri.groups()
        storage_client = storage.Client()
Reviewer: We may want to wrap this in a try/except block to provide a clearer error. Or do you think the error is clear enough if they are not running in a Google Cloud environment or have not set up credentials?
Reviewer: The other langchain packages don't have running integration tests for 3P providers. We could mock this test or just test this functionality in our package downstream.
Author: Currently this is the error. The options for the tests are: … What do you suggest?
Reviewer: If the GCS storage call can be easily mocked, let's go ahead and do that. If it can't, let's keep the test but skip it.
Author: Currently, the mock solutions may need more debugging, so I've removed the GCS URI from being tested; the other images being created locally are still under the test.
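For reference, a minimal sketch of how the GCS call could be mocked in a unit test, as suggested above. The module path, the `make_test_store` helper, and the bucket/object names are illustrative assumptions, not part of this PR:

```python
import base64
from unittest import mock


def test_encode_image_gcs_uri_is_base64_encoded():
    fake_bytes = b"\x89PNG-fake-image-bytes"

    # Patch storage.Client where _encode_image imports it, so no credentials
    # or network access are required. The module path is an assumption.
    with mock.patch("langchain_postgres.vectorstores.storage.Client") as mock_client:
        mock_blob = mock_client.return_value.bucket.return_value.blob.return_value
        mock_blob.download_as_bytes.return_value = fake_bytes

        store = make_test_store()  # hypothetical helper that builds a vector store
        encoded = store._encode_image("gs://my-bucket/cat.png")

    assert encoded == base64.b64encode(fake_bytes).decode("utf-8")
```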
        bucket = storage_client.bucket(bucket_name)
        blob = bucket.blob(object_name)
        return base64.b64encode(blob.download_as_bytes()).decode("utf-8")

    web_uri = re.match(r"^(https?://).*", uri)
Reviewer: Regular expressions are not a good solution here. Simple prefix matching is more robust and will probably also be faster in terms of actual run time.
Author: Used urlparse instead. If prefix matching is preferred, I can make that change.
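As a side note on the prefix-matching suggestion, a small sketch of what that could look like (illustrative only; this helper is not part of the diff, which at this point still uses `re.match`):

```python
from urllib.parse import urlparse


def _uri_kind(uri: str) -> str:
    """Classify an image URI with plain prefix checks instead of regexes."""
    if uri.startswith("gs://"):
        return "gcs"
    if uri.startswith(("http://", "https://")):
        return "web"
    # urlparse gives the same information without regular expressions,
    # e.g. urlparse("gs://bucket/obj").scheme == "gs".
    return "local"
```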

    if web_uri:
        response = requests.get(uri, stream=True)
Reviewer: This is an SSRF attack.

        response.raise_for_status()
        return base64.b64encode(response.content).decode("utf-8")

    with open(uri, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode("utf-8")
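On the SSRF concern raised above (and again further down), one possible mitigation is to validate URIs against an allowlist before fetching anything. This is only a sketch under assumed configuration values; the PR itself does not include it:

```python
from urllib.parse import urlparse

# Assumed, application-specific configuration; not part of the PR.
ALLOWED_SCHEMES = {"gs", "https"}
ALLOWED_HOSTS = {"storage.googleapis.com"}


def _validate_image_uri(uri: str) -> None:
    """Reject URIs that could enable SSRF or unintended local file reads."""
    parsed = urlparse(uri)
    if parsed.scheme not in ALLOWED_SCHEMES:
        raise ValueError(f"URI scheme {parsed.scheme!r} is not allowed: {uri}")
    if parsed.scheme == "https" and parsed.hostname not in ALLOWED_HOSTS:
        raise ValueError(f"Host {parsed.hostname!r} is not in the allowlist: {uri}")
```

Restricting schemes to `gs` and `https` also sidesteps plain local paths, in line with the reviewer's later point that production code has little reason to read the local file system.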

async def aadd_images(
    self,
    uris: list[str],
Reviewer: Accepting URIs without safeguards is an SSRF attack.
Author: My understanding is that SSRF attacks are generally dealt with by the application layer. Is that correct, or is it more of a framework responsibility?
Reviewer: The way I'd think about this is: given that this code is supposed to be optimized for production, there really isn't a reason to access the local file system.

    metadatas: Optional[list[dict]] = None,
    ids: Optional[list[str]] = None,
    **kwargs: Any,
) -> list[str]:
    """Embed images and add to the table.

    Args:
        uris (list[str]): List of local image URIs to add to the table.
        metadatas (Optional[list[dict]]): List of metadatas to add to table records.
        ids (Optional[list[str]]): List of IDs to add to table records.

    Returns:
        List of record IDs added.
    """
    encoded_images = []
    if metadatas is None:
        metadatas = [{"image_uri": uri} for uri in uris]

    for uri in uris:
        encoded_image = self._encode_image(uri)
        encoded_images.append(encoded_image)

    embeddings = self._images_embedding_helper(uris)
    ids = await self.aadd_embeddings(
        encoded_images, embeddings, metadatas=metadatas, ids=ids, **kwargs
    )
    return ids
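A short usage sketch for `aadd_images` (illustrative; it assumes `store` is an already-initialized vector store whose embedding service accepts image URIs, and the URIs themselves are made up):

```python
import asyncio


async def index_images(store):
    uris = [
        "gs://my-bucket/cat.png",        # hypothetical GCS object
        "https://example.com/dog.jpg",   # hypothetical web image
        "./images/bird.jpg",             # local file
    ]
    metadatas = [{"label": "cat"}, {"label": "dog"}, {"label": "bird"}]
    ids = await store.aadd_images(uris, metadatas=metadatas)
    print(f"Added {len(ids)} image records")

# asyncio.run(index_images(store))  # once a store instance is configured
```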

def _images_embedding_helper(self, image_uris: list[str]) -> list[list[float]]:
    # check if either `embed_images()` or `embed_image()` API is supported by the embedding service used
    if hasattr(self.embedding_service, "embed_images"):
        try:
            embeddings = self.embedding_service.embed_images(image_uris)
        except Exception as e:
            raise Exception(
                f"Make sure your selected embedding model supports list of image URIs as input. {str(e)}"
            )
    elif hasattr(self.embedding_service, "embed_image"):
        try:
            embeddings = self.embedding_service.embed_image(image_uris)
        except Exception as e:
            raise Exception(
                f"Make sure your selected embedding model supports list of image URIs as input. {str(e)}"
            )
    else:
        raise ValueError(
            "Please use an embedding model that supports image embedding."
        )
    return embeddings
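To make the duck-typed check above concrete, here is a minimal sketch of an embedding service exposing `embed_image`. The class and its vectors are purely illustrative; any embedding model that provides `embed_image` or `embed_images` would pass the `hasattr` checks:

```python
from langchain_core.embeddings import Embeddings


class FakeMultimodalEmbeddings(Embeddings):
    """Toy embedding service that satisfies the embed_image hook used above."""

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        return [[0.0, 0.0, 0.0] for _ in texts]

    def embed_query(self, text: str) -> list[float]:
        return [0.0, 0.0, 0.0]

    def embed_image(self, image_uris: list[str]) -> list[list[float]]:
        # A real implementation would load each image and run a multimodal model.
        return [[0.1, 0.2, 0.3] for _ in image_uris]
```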

async def asimilarity_search_image(
    self,
    image_uri: str,
    k: Optional[int] = None,
    filter: Optional[dict] = None,
    **kwargs: Any,
) -> list[Document]:
    """Return docs selected by similarity search on query."""
    embedding = self._images_embedding_helper([image_uri])[0]

    return await self.asimilarity_search_by_vector(
        embedding=embedding, k=k, filter=filter, **kwargs
    )
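And a matching usage sketch for image-to-image search (illustrative; `store` is again assumed to be an initialized instance and the query URI is made up):

```python
async def find_similar_images(store):
    docs = await store.asimilarity_search_image(
        "gs://my-bucket/query.png",   # hypothetical query image
        k=5,
        filter={"label": "cat"},      # optional metadata filter
    )
    for doc in docs:
        # page_content holds the stored base64 image; metadata keeps the URI.
        print(doc.metadata.get("image_uri"), doc.page_content[:32])
```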

async def adelete(
    self,
    ids: Optional[list] = None,

@@ -1268,3 +1358,25 @@ def max_marginal_relevance_search_with_score_by_vector(
    raise NotImplementedError(
        "Sync methods are not implemented for AsyncPGVectorStore. Use PGVectorStore interface instead."
    )

def add_images(
    self,
    uris: list[str],
    metadatas: Optional[list[dict]] = None,
    ids: Optional[list[str]] = None,
    **kwargs: Any,
) -> list[str]:
    raise NotImplementedError(
        "Sync methods are not implemented for AsyncPGVectorStore. Use PGVectorStore interface instead."
    )

def similarity_search_image(
    self,
    image_uri: str,
    k: Optional[int] = None,
    filter: Optional[dict] = None,
    **kwargs: Any,
) -> list[Document]:
    raise NotImplementedError(
        "Sync methods are not implemented for AsyncPGVectorStore. Use PGVectorStore interface instead."
    )
Author: I've removed the import from the global namespace. I'm not sure how I should use the key-value store in this case. Could you please point me to the right usage?