generated from langchain-ai/integration-repo-template
-
Notifications
You must be signed in to change notification settings - Fork 20
feat: Added WatsonxRerank integration, update logic of passing params
#33
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 3 commits
Commits
Show all changes
17 commits
Select commit
Hold shift + click to select a range
ee26a2c
add: Added WatsonxRerank integration
MateuszOssGit a84ddb5
update logic of passing params in method
MateuszOssGit eefce98
fixed test_imports
MateuszOssGit f07b518
update logic of passing params in llm
MateuszOssGit a76cf58
added extract_params method, simplify code
MateuszOssGit 3a24188
make format, make lint
MateuszOssGit f2ef68c
Added standard EmbeddingsIntegrationTests
MateuszOssGit e832883
Added standard TestWatsonxEmbeddingsStandard
MateuszOssGit bc7434d
Added rerank unit tests
MateuszOssGit 85403f5
Update Chat standard tests
MateuszOssGit 573fcc6
Simplify of using token in chat requests
MateuszOssGit 0395f9b
v1
MateuszOssGit 7ee61bf
Update standard chat tests
MateuszOssGit 72b7acf
Merge branch 'main' into dev-added-rerank
MateuszOssGit 369a52f
Update property method on chat standard tests
MateuszOssGit af354ad
poetry update
MateuszOssGit 0b3267c
update chat models standard
MateuszOssGit File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
from langchain_ibm.chat_models import ChatWatsonx | ||
from langchain_ibm.embeddings import WatsonxEmbeddings | ||
from langchain_ibm.llms import WatsonxLLM | ||
from langchain_ibm.rerank import WatsonxRerank | ||
|
||
__all__ = ["WatsonxLLM", "WatsonxEmbeddings", "ChatWatsonx"] | ||
__all__ = ["WatsonxLLM", "WatsonxEmbeddings", "ChatWatsonx", "WatsonxRerank"] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,246 @@ | ||
from __future__ import annotations | ||
|
||
from copy import deepcopy | ||
from typing import Any, Dict, List, Optional, Sequence, Union | ||
|
||
from ibm_watsonx_ai import APIClient, Credentials # type: ignore | ||
from ibm_watsonx_ai.foundation_models import Rerank # type: ignore | ||
from ibm_watsonx_ai.foundation_models.schema import ( # type: ignore | ||
BaseSchema, | ||
RerankParameters, | ||
) | ||
from langchain_core.callbacks import Callbacks | ||
from langchain_core.documents import BaseDocumentCompressor, Document | ||
from langchain_core.utils.utils import secret_from_env | ||
from pydantic import ConfigDict, Field, SecretStr, model_validator | ||
from typing_extensions import Self | ||
|
||
from langchain_ibm.utils import check_for_attribute | ||
|
||
|
||
class WatsonxRerank(BaseDocumentCompressor):
    """Document compressor backed by the `watsonx Rerank API`.

    A ``Rerank`` object from ``ibm_watsonx_ai`` is created while the model is
    validated, either around a caller-supplied ``APIClient`` or from the
    credential fields declared below, and is then used to score documents
    against a query.
    """

    model_id: str
    """Type of model to use."""

    project_id: Optional[str] = None
    """Watson Studio project ID."""

    space_id: Optional[str] = None
    """Watson Studio space ID."""

    url: SecretStr = Field(
        alias="url", default_factory=secret_from_env("WATSONX_URL", default=None)
    )
    """URL of the Watson Machine Learning or CPD instance."""

    apikey: Optional[SecretStr] = Field(
        alias="apikey", default_factory=secret_from_env("WATSONX_APIKEY", default=None)
    )
    """API key for the Watson Machine Learning or CPD instance."""

    token: Optional[SecretStr] = Field(
        alias="token", default_factory=secret_from_env("WATSONX_TOKEN", default=None)
    )
    """Token for the CPD instance."""

    password: Optional[SecretStr] = Field(
        alias="password",
        default_factory=secret_from_env("WATSONX_PASSWORD", default=None),
    )
    """Password for the CPD instance."""

    username: Optional[SecretStr] = Field(
        alias="username",
        default_factory=secret_from_env("WATSONX_USERNAME", default=None),
    )
    """Username for the CPD instance."""

    instance_id: Optional[SecretStr] = Field(
        alias="instance_id",
        default_factory=secret_from_env("WATSONX_INSTANCE_ID", default=None),
    )
    """Instance_id of the CPD instance."""

    version: Optional[SecretStr] = None
    """Version of the CPD instance."""

    params: Optional[Union[dict, RerankParameters]] = None
    """Default model parameters sent with a rerank request."""

    verify: Union[str, bool, None] = None
    """Certificate verification setting. One of:
        * the path to a CA_BUNDLE file
        * the path of directory with certificates of trusted CAs
        * True - default path to truststore will be taken
        * False - no verification will be made"""

    validate_model: bool = True
    """Model ID validation."""

    streaming: bool = False
    """Whether to stream the results or not."""

    watsonx_rerank: Rerank = Field(default=None, exclude=True)  #: :meta private:

    watsonx_client: Optional[APIClient] = Field(default=None, exclude=True)

    model_config = ConfigDict(
        extra="forbid",
        populate_by_name=True,
        arbitrary_types_allowed=True,
    )

    @property
    def lc_secrets(self) -> Dict[str, str]:
        """Map secret constructor arguments to their secret ids.

        Each key is a constructor argument name and each value is the
        matching ``WATSONX_*`` identifier, e.g. ``"url"`` maps to
        ``"WATSONX_URL"`` and ``"apikey"`` maps to ``"WATSONX_APIKEY"``.
        """
        return dict(
            url="WATSONX_URL",
            apikey="WATSONX_APIKEY",
            token="WATSONX_TOKEN",
            password="WATSONX_PASSWORD",
            username="WATSONX_USERNAME",
            instance_id="WATSONX_INSTANCE_ID",
        )

    @model_validator(mode="after")
    def validate_environment(self) -> Self:
        """Check the credential fields and create the ``Rerank`` object.

        If ``watsonx_client`` is an ``APIClient`` it is handed to ``Rerank``
        directly and no credential field is inspected. Otherwise ``url`` is
        checked first; a URL containing ``cloud.ibm.com`` additionally needs
        ``apikey``, while any other URL needs one of ``token``, ``password``
        (with ``username``) or ``apikey`` (with ``username``), plus
        ``instance_id``.

        Returns:
            This instance, with ``watsonx_rerank`` populated.

        Raises:
            ValueError: If the URL is not a ``cloud.ibm.com`` one and none of
                ``token``, ``password`` or ``apikey`` is set.
                NOTE(review): ``check_for_attribute`` is defined elsewhere and
                presumably raises as well when a required value is missing.
        """
        # Keyword arguments common to both ways of constructing ``Rerank``.
        shared_kwargs: Dict[str, Any] = {
            "model_id": self.model_id,
            "params": self.params,
            "project_id": self.project_id,
            "space_id": self.space_id,
            "verify": self.verify,
        }

        # A ready-made client carries its own credentials: use it and stop.
        if isinstance(self.watsonx_client, APIClient):
            self.watsonx_rerank = Rerank(
                api_client=self.watsonx_client, **shared_kwargs
            )
            return self

        check_for_attribute(self.url, "url", "WATSONX_URL")

        if "cloud.ibm.com" in self.url.get_secret_value():
            check_for_attribute(self.apikey, "apikey", "WATSONX_APIKEY")
        else:
            # Non-cloud URL: the first available credential wins, in the
            # order token, password, apikey.
            if self.token:
                check_for_attribute(self.token, "token", "WATSONX_TOKEN")
            elif self.password:
                check_for_attribute(self.password, "password", "WATSONX_PASSWORD")
                check_for_attribute(self.username, "username", "WATSONX_USERNAME")
            elif self.apikey:
                check_for_attribute(self.apikey, "apikey", "WATSONX_APIKEY")
                check_for_attribute(self.username, "username", "WATSONX_USERNAME")
            else:
                raise ValueError(
                    "Did not find 'token', 'password' or 'apikey',"
                    " please add an environment variable"
                    " `WATSONX_TOKEN`, 'WATSONX_PASSWORD' or 'WATSONX_APIKEY' "
                    "which contains it,"
                    " or pass 'token', 'password' or 'apikey'"
                    " as a named parameter."
                )

            if not self.instance_id:
                check_for_attribute(
                    self.instance_id, "instance_id", "WATSONX_INSTANCE_ID"
                )

        def _reveal(secret: Optional[SecretStr]) -> Optional[str]:
            """Return the plain value of ``secret``, or None when it is unset."""
            return secret.get_secret_value() if secret else None

        credentials = Credentials(
            url=_reveal(self.url),
            api_key=_reveal(self.apikey),
            token=_reveal(self.token),
            password=_reveal(self.password),
            username=_reveal(self.username),
            instance_id=_reveal(self.instance_id),
            version=_reveal(self.version),
            verify=self.verify,
        )

        self.watsonx_rerank = Rerank(credentials=credentials, **shared_kwargs)
        return self

    def rerank(
        self,
        documents: Sequence[Union[str, Document, dict]],
        query: str,
        **kwargs: Any,
    ) -> List[Dict[str, Any]]:
        """Score ``documents`` against ``query`` with the watsonx rerank model.

        Args:
            documents: Items to rank. ``Document`` objects contribute their
                ``page_content``; every other item is forwarded unchanged.
            query: The query the documents are ranked against.
            **kwargs: Extra keyword arguments forwarded to
                ``Rerank.generate``. A non-None ``params`` entry takes
                precedence over the instance-level ``params``.

        Returns:
            One dict per entry of the response's ``"results"`` list, holding
            its ``"index"`` and its score under ``"relevance_score"``. An
            empty list when ``documents`` is empty.
        """
        if len(documents) == 0:  # nothing to rank, so skip the API call
            return []

        texts = [
            item.page_content if isinstance(item, Document) else item
            for item in documents
        ]
        # The resolved params replace any raw ``params`` entry in kwargs.
        call_kwargs = {**kwargs, "params": self._get_rerank_params(**kwargs)}

        response = self.watsonx_rerank.generate(
            query=query, inputs=texts, **call_kwargs
        )
        return [
            {"index": hit.get("index"), "relevance_score": hit.get("score")}
            for hit in response["results"]
        ]

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
        **kwargs: Any,
    ) -> Sequence[Document]:
        """
        Compress documents using watsonx's rerank API.

        Args:
            documents: A sequence of documents to compress.
            query: The query to use for compressing the documents.
            callbacks: Callbacks to run during the compression process
                (not used by this implementation).
            **kwargs: Forwarded to ``rerank``.

        Returns:
            Copies of the ranked documents, in the order returned by
            ``rerank``, each with ``relevance_score`` added to its metadata.
        """
        reranked: List[Document] = []
        for hit in self.rerank(documents, query, **kwargs):
            source = documents[hit["index"]]
            # Copy the metadata so the caller's documents are left untouched.
            scored = Document(source.page_content, metadata=deepcopy(source.metadata))
            scored.metadata["relevance_score"] = hit["relevance_score"]
            reranked.append(scored)
        return reranked

    def _get_rerank_params(self, **kwargs: Any) -> Dict[str, Any]:
        """Resolve the parameters to send with a rerank request.

        A non-None ``params`` keyword argument wins over the instance-level
        ``params``. A ``BaseSchema`` value is converted with ``to_dict()``.
        An empty dict is returned when neither source provides parameters.
        """
        chosen = kwargs.get("params")
        if chosen is None:
            chosen = self.params

        if isinstance(chosen, BaseSchema):
            chosen = chosen.to_dict()

        return chosen or {}
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.