Skip to content

Commit eebdd41

Browse files
authored
Fix/serialization of llm plugins (#445)
2 parents b0ad6e5 + 9ec5850 commit eebdd41

File tree

5 files changed

+31
-35
lines changed

5 files changed

+31
-35
lines changed

ads/llm/langchain/plugins/base.py

Lines changed: 6 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,28 +3,20 @@
33

44
# Copyright (c) 2023 Oracle and/or its affiliates.
55
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6-
from enum import Enum
76
from typing import Any, Dict, List, Optional
87

98
from langchain.llms.base import LLM
10-
from langchain.pydantic_v1 import BaseModel, root_validator, Field
9+
from langchain.load.serializable import Serializable
10+
from langchain.pydantic_v1 import BaseModel, Field, root_validator
11+
1112
from ads.common.auth import default_signer
1213
from ads.config import COMPARTMENT_OCID
1314

1415

15-
class StrEnum(str, Enum):
16-
"""Enum with string members
17-
https://docs.python.org/3.11/library/enum.html#enum.StrEnum
18-
"""
19-
20-
# Pydantic uses Python's standard enum classes to define choices.
21-
# https://docs.pydantic.dev/latest/api/standard_library_types/#enum
22-
23-
24-
class BaseLLM(LLM):
16+
class BaseLLM(LLM, Serializable):
2517
"""Base OCI LLM class. Contains common attributes."""
2618

27-
auth: dict = Field(default_factory=default_signer)
19+
auth: dict = Field(default_factory=default_signer, exclude=True)
2820
"""ADS auth dictionary for OCI authentication.
2921
This can be generated by calling `ads.common.auth.api_keys()` or `ads.common.auth.resource_principal()`.
3022
If this is not provided then the `ads.common.default_signer()` will be used."""
@@ -54,8 +46,7 @@ def _print_response(self, completion, response):
5446

5547
@classmethod
5648
def get_lc_namespace(cls) -> List[str]:
57-
"""Get the namespace of the langchain object.
58-
"""
49+
"""Get the namespace of the langchain object."""
5950
return ["ads", "llm"]
6051

6152
@classmethod
@@ -68,11 +59,6 @@ class GenerativeAiClientModel(BaseModel):
6859
client: Any #: :meta private:
6960
"""OCI GenerativeAiClient."""
7061

71-
auth: dict = Field(default_factory=default_signer)
72-
"""ADS auth dictionary for OCI authentication.
73-
This can be generated by calling `ads.common.auth.api_keys()` or `ads.common.auth.resource_principal()`.
74-
If this is not provided then the `ads.common.default_signer()` will be used."""
75-
7662
compartment_id: str
7763
"""Compartment ID of the caller."""
7864

ads/llm/langchain/plugins/contant.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,17 @@
33

44
# Copyright (c) 2023 Oracle and/or its affiliates.
55
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6+
from enum import Enum
7+
8+
9+
class StrEnum(str, Enum):
10+
"""Enum with string members
11+
https://docs.python.org/3.11/library/enum.html#enum.StrEnum
12+
"""
13+
14+
# Pydantic uses Python's standard enum classes to define choices.
15+
# https://docs.pydantic.dev/latest/api/standard_library_types/#enum
616

7-
from ads.llm.langchain.plugins.base import StrEnum
817

918
DEFAULT_TIME_OUT = 300
1019
DEFAULT_CONTENT_TYPE_JSON = "application/json"

ads/llm/langchain/plugins/llm_gen_ai.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from typing import Any, Dict, List, Optional
99

1010
from langchain.callbacks.manager import CallbackManagerForLLMRun
11-
from oci.util import to_dict
11+
1212
from ads.llm.langchain.plugins.base import BaseLLM, GenerativeAiClientModel
1313
from ads.llm.langchain.plugins.contant import *
1414

@@ -31,10 +31,10 @@ class GenerativeAI(GenerativeAiClientModel, BaseLLM):
3131
3232
"""
3333

34-
task: Task = Task.TEXT_GENERATION
34+
task: str = "text_generation"
3535
"""Indicates the task."""
3636

37-
model: Optional[str] = OCIGenerativeAIModel.COHERE_COMMAND
37+
model: Optional[str] = "cohere.command"
3838
"""Model name to use."""
3939

4040
frequency_penalty: float = None
@@ -46,13 +46,13 @@ class GenerativeAI(GenerativeAiClientModel, BaseLLM):
4646
truncate: Optional[str] = None
4747
"""Specify how the client handles inputs longer than the maximum token."""
4848

49-
length: str = LengthParam.AUTO
49+
length: str = "AUTO"
5050
"""Indicates the approximate length of the summary. """
5151

52-
format: str = FormatParam.PARAGRAPH
52+
format: str = "PARAGRAPH"
5353
"""Indicates the style in which the summary will be delivered - in a free form paragraph or in bullet points."""
5454

55-
extractiveness: str = ExtractivenessParam.AUTO
55+
extractiveness: str = "AUTO"
5656
"""Controls how close to the original text the summary is. High extractiveness summaries will lean towards reusing sentences verbatim, while low extractiveness summaries will tend to paraphrase more."""
5757

5858
additional_command: str = ""
@@ -181,8 +181,8 @@ def _process_response(self, response: Any, num_generations: int = 1) -> str:
181181
def completion_with_retry(self, **kwargs: Any) -> Any:
182182
from oci.generative_ai.models import (
183183
GenerateTextDetails,
184-
SummarizeTextDetails,
185184
OnDemandServingMode,
185+
SummarizeTextDetails,
186186
)
187187

188188
# TODO: Add retry logic for OCI

ads/llm/langchain/plugins/llm_md.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import requests
1111
from langchain.callbacks.manager import CallbackManagerForLLMRun
12+
1213
from ads.llm.langchain.plugins.base import BaseLLM
1314
from ads.llm.langchain.plugins.contant import (
1415
DEFAULT_CONTENT_TYPE_JSON,
@@ -25,8 +26,8 @@ class ModelDeploymentLLM(BaseLLM):
2526
"""The uri of the endpoint from the deployed Model Deployment model."""
2627

2728
best_of: int = 1
28-
"""Generates best_of completions server-side and returns the "best"
29-
(the one with the highest log probability per token).
29+
"""Generates best_of completions server-side and returns the "best"
30+
(the one with the highest log probability per token).
3031
"""
3132

3233
@property
@@ -230,7 +231,7 @@ class ModelDeploymentVLLM(ModelDeploymentLLM):
230231
"""Whether to use beam search instead of sampling."""
231232

232233
ignore_eos: bool = False
233-
"""Whether to ignore the EOS token and continue generating tokens after
234+
"""Whether to ignore the EOS token and continue generating tokens after
234235
the EOS token is generated."""
235236

236237
logprobs: Optional[int] = None

ads/llm/serialize.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from langchain.load.serializable import Serializable
1919

2020
from ads.common.auth import default_signer
21+
from ads.common.object_storage_details import ObjectStorageDetails
2122
from ads.llm import GenerativeAI, ModelDeploymentVLLM, ModelDeploymentTGI
2223
from ads.llm.chain import GuardrailSequence
2324
from ads.llm.guardrails.base import CustomGuardrailBase
@@ -115,12 +116,11 @@ def ignore_unknown(self, node):
115116
None, _SafeLoaderIgnoreUnknown.ignore_unknown
116117
)
117118

118-
if uri.startswith("oci://"):
119-
storage_options = default_signer()
120-
else:
121-
storage_options = {}
119+
storage_options = default_signer() if ObjectStorageDetails.is_oci_path(uri) else {}
120+
122121
with fsspec.open(uri, **storage_options) as f:
123122
config = yaml.load(f, Loader=_SafeLoaderIgnoreUnknown)
123+
124124
return load(
125125
config, secrets_map=secrets_map, valid_namespaces=valid_namespaces, **kwargs
126126
)

0 commit comments

Comments
 (0)