Skip to content

Commit 9353b2e

Browse files
committed
merge conflicts
2 parents 3e78ff5 + 3418eea commit 9353b2e

File tree

7 files changed

+71
-52
lines changed

7 files changed

+71
-52
lines changed

ads/llm/chain.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,9 @@ def run(self, input: Any, num_generations: int = 1, **kwargs) -> GuardrailIO:
178178
if self.log_info or os.environ.get(LOG_ADS_GUARDRAIL_INFO) == "1":
179179
# LOG_ADS_GUARDRAIL_INFO is set to "1" in score.py by default.
180180
print(obj.dict())
181+
# If the output is a singleton list, take it out of the list.
182+
if isinstance(obj.data, list) and len(obj.data) == 1:
183+
obj.data = obj.data[0]
181184
return obj
182185

183186
def _save_to_file(self, chain_dict, filename, overwrite=False):

ads/llm/guardrails/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def __repr__(self) -> str:
7878
steps.append(f"{run_info.name} - {run_info.metrics}")
7979
if run_info:
8080
steps.append(str(run_info.output))
81-
return "\n".join(steps)
81+
return "\n".join(steps) + "\n\n" + str(self)
8282

8383

8484
class BlockedByGuardrail(ToolException):

ads/llm/langchain/plugins/base.py

Lines changed: 14 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3,28 +3,20 @@
33

44
# Copyright (c) 2023 Oracle and/or its affiliates.
55
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6-
from enum import Enum
76
from typing import Any, Dict, List, Optional
87

98
from langchain.llms.base import LLM
10-
from langchain.pydantic_v1 import BaseModel, root_validator, Field
9+
from langchain.load.serializable import Serializable
10+
from langchain.pydantic_v1 import BaseModel, Field, root_validator
11+
1112
from ads.common.auth import default_signer
1213
from ads.config import COMPARTMENT_OCID
1314

1415

15-
class StrEnum(str, Enum):
16-
"""Enum with string members
17-
https://docs.python.org/3.11/library/enum.html#enum.StrEnum
18-
"""
19-
20-
# Pydantic uses Python's standard enum classes to define choices.
21-
# https://docs.pydantic.dev/latest/api/standard_library_types/#enum
22-
23-
24-
class BaseLLM(LLM):
16+
class BaseLLM(LLM, Serializable):
2517
"""Base OCI LLM class. Contains common attributes."""
2618

27-
auth: dict = Field(default_factory=default_signer)
19+
auth: dict = Field(default_factory=default_signer, exclude=True)
2820
"""ADS auth dictionary for OCI authentication.
2921
This can be generated by calling `ads.common.auth.api_keys()` or `ads.common.auth.resource_principal()`.
3022
If this is not provided then the `ads.common.default_signer()` will be used."""
@@ -54,8 +46,7 @@ def _print_response(self, completion, response):
5446

5547
@classmethod
5648
def get_lc_namespace(cls) -> List[str]:
57-
"""Get the namespace of the langchain object.
58-
"""
49+
"""Get the namespace of the LangChain object."""
5950
return ["ads", "llm"]
6051

6152
@classmethod
@@ -65,15 +56,12 @@ def is_lc_serializable(cls) -> bool:
6556

6657

6758
class GenerativeAiClientModel(BaseModel):
59+
"""Base model for generative AI embedding model and LLM."""
60+
6861
client: Any #: :meta private:
6962
"""OCI GenerativeAiClient."""
7063

71-
auth: dict = Field(default_factory=default_signer)
72-
"""ADS auth dictionary for OCI authentication.
73-
This can be generated by calling `ads.common.auth.api_keys()` or `ads.common.auth.resource_principal()`.
74-
If this is not provided then the `ads.common.default_signer()` will be used."""
75-
76-
compartment_id: str
64+
compartment_id: str = None
7765
"""Compartment ID of the caller."""
7866

7967
endpoint_kwargs: Dict[str, Any] = {}
@@ -104,7 +92,9 @@ def validate_environment( # pylint: disable=no-self-argument
10492
client_kwargs.update(values["client_kwargs"])
10593
values["client"] = GenerativeAiClient(**auth, **client_kwargs)
10694
# Set default compartment ID
107-
if "compartment_id" not in values and COMPARTMENT_OCID:
108-
values["compartment_id"] = COMPARTMENT_OCID
109-
95+
if not values.get("compartment_id"):
96+
if COMPARTMENT_OCID:
97+
values["compartment_id"] = COMPARTMENT_OCID
98+
else:
99+
raise ValueError("Please specify compartment_id.")
110100
return values

ads/llm/langchain/plugins/contant.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,17 @@
33

44
# Copyright (c) 2023 Oracle and/or its affiliates.
55
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6+
from enum import Enum
7+
8+
9+
class StrEnum(str, Enum):
10+
"""Enum with string members
11+
https://docs.python.org/3.11/library/enum.html#enum.StrEnum
12+
"""
13+
14+
# Pydantic uses Python's standard enum classes to define choices.
15+
# https://docs.pydantic.dev/latest/api/standard_library_types/#enum
616

7-
from ads.llm.langchain.plugins.base import StrEnum
817

918
DEFAULT_TIME_OUT = 300
1019
DEFAULT_CONTENT_TYPE_JSON = "application/json"

ads/llm/langchain/plugins/llm_gen_ai.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from typing import Any, Dict, List, Optional
99

1010
from langchain.callbacks.manager import CallbackManagerForLLMRun
11-
from oci.util import to_dict
11+
1212
from ads.llm.langchain.plugins.base import BaseLLM, GenerativeAiClientModel
1313
from ads.llm.langchain.plugins.contant import *
1414

@@ -31,10 +31,10 @@ class GenerativeAI(GenerativeAiClientModel, BaseLLM):
3131
3232
"""
3333

34-
task: Task = Task.TEXT_GENERATION
34+
task: str = "text_generation"
3535
"""Indicates the task."""
3636

37-
model: Optional[str] = OCIGenerativeAIModel.COHERE_COMMAND
37+
model: Optional[str] = "cohere.command"
3838
"""Model name to use."""
3939

4040
frequency_penalty: float = None
@@ -46,13 +46,13 @@ class GenerativeAI(GenerativeAiClientModel, BaseLLM):
4646
truncate: Optional[str] = None
4747
"""Specify how the client handles inputs longer than the maximum token."""
4848

49-
length: str = LengthParam.AUTO
49+
length: str = "AUTO"
5050
"""Indicates the approximate length of the summary. """
5151

52-
format: str = FormatParam.PARAGRAPH
52+
format: str = "PARAGRAPH"
5353
"""Indicates the style in which the summary will be delivered - in a free form paragraph or in bullet points."""
5454

55-
extractiveness: str = ExtractivenessParam.AUTO
55+
extractiveness: str = "AUTO"
5656
"""Controls how close to the original text the summary is. High extractiveness summaries will lean towards reusing sentences verbatim, while low extractiveness summaries will tend to paraphrase more."""
5757

5858
additional_command: str = ""
@@ -181,8 +181,8 @@ def _process_response(self, response: Any, num_generations: int = 1) -> str:
181181
def completion_with_retry(self, **kwargs: Any) -> Any:
182182
from oci.generative_ai.models import (
183183
GenerateTextDetails,
184-
SummarizeTextDetails,
185184
OnDemandServingMode,
185+
SummarizeTextDetails,
186186
)
187187

188188
# TODO: Add retry logic for OCI

ads/llm/langchain/plugins/llm_md.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import requests
1111
from langchain.callbacks.manager import CallbackManagerForLLMRun
12+
1213
from ads.llm.langchain.plugins.base import BaseLLM
1314
from ads.llm.langchain.plugins.contant import (
1415
DEFAULT_CONTENT_TYPE_JSON,
@@ -25,8 +26,8 @@ class ModelDeploymentLLM(BaseLLM):
2526
"""The uri of the endpoint from the deployed Model Deployment model."""
2627

2728
best_of: int = 1
28-
"""Generates best_of completions server-side and returns the "best"
29-
(the one with the highest log probability per token).
29+
"""Generates best_of completions server-side and returns the "best"
30+
(the one with the highest log probability per token).
3031
"""
3132

3233
@property
@@ -230,7 +231,7 @@ class ModelDeploymentVLLM(ModelDeploymentLLM):
230231
"""Whether to use beam search instead of sampling."""
231232

232233
ignore_eos: bool = False
233-
"""Whether to ignore the EOS token and continue generating tokens after
234+
"""Whether to ignore the EOS token and continue generating tokens after
234235
the EOS token is generated."""
235236

236237
logprobs: Optional[int] = None

ads/llm/serialize.py

Lines changed: 32 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@
2323
from opensearchpy.client import OpenSearch
2424

2525
from ads.common.auth import default_signer
26-
from ads.llm import GenerativeAI, ModelDeploymentTGI, ModelDeploymentVLLM
26+
from ads.common.object_storage_details import ObjectStorageDetails
27+
from ads.llm import GenerativeAI, ModelDeploymentVLLM, ModelDeploymentTGI
2728
from ads.llm.chain import GuardrailSequence
2829
from ads.llm.guardrails.base import CustomGuardrailBase
2930
from ads.llm.patch import RunnableParallel, RunnableParallelSerializer
@@ -241,12 +242,11 @@ def ignore_unknown(self, node):
241242
None, _SafeLoaderIgnoreUnknown.ignore_unknown
242243
)
243244

244-
if uri.startswith("oci://"):
245-
storage_options = default_signer()
246-
else:
247-
storage_options = {}
245+
storage_options = default_signer() if ObjectStorageDetails.is_oci_path(uri) else {}
246+
248247
with fsspec.open(uri, **storage_options) as f:
249248
config = yaml.load(f, Loader=_SafeLoaderIgnoreUnknown)
249+
250250
return load(
251251
config, secrets_map=secrets_map, valid_namespaces=valid_namespaces, **kwargs
252252
)
@@ -275,6 +275,22 @@ def default(obj: Any) -> Any:
275275
raise TypeError(f"Serialization of {type(obj)} is not supported.")
276276

277277

278+
def __save(obj):
279+
"""Calls the legacy save method to save the object to temp json
280+
then load it into a dictionary.
281+
"""
282+
try:
283+
temp_file = tempfile.NamedTemporaryFile(
284+
mode="w", encoding="utf-8", suffix=".json", delete=False
285+
)
286+
temp_file.close()
287+
obj.save(temp_file.name)
288+
with open(temp_file.name, "r", encoding="utf-8") as f:
289+
return json.load(f)
290+
finally:
291+
os.unlink(temp_file.name)
292+
293+
278294
def dump(obj: Any) -> Dict[str, Any]:
279295
"""Return a json dict representation of an object.
280296
@@ -293,14 +309,14 @@ def dump(obj: Any) -> Dict[str, Any]:
293309
):
294310
# The object is not is_lc_serializable.
295311
# However, it supports the legacy save() method.
296-
try:
297-
temp_file = tempfile.NamedTemporaryFile(
298-
mode="w", encoding="utf-8", suffix=".json", delete=False
299-
)
300-
temp_file.close()
301-
obj.save(temp_file.name)
302-
with open(temp_file.name, "r", encoding="utf-8") as f:
303-
return json.load(f)
304-
finally:
305-
os.unlink(temp_file.name)
306-
return json.loads(json.dumps(obj, default=default))
312+
return __save(obj)
313+
# The object is is_lc_serializable.
314+
# However, some properties may not be serializable
315+
# Here we try to dump the object and fallback to the save() method
316+
# if there is an error.
317+
try:
318+
return json.loads(json.dumps(obj, default=default))
319+
except TypeError as ex:
320+
if isinstance(obj, Serializable) and hasattr(obj, "save"):
321+
return __save(obj)
322+
raise ex

0 commit comments

Comments (0)