Skip to content

Commit 3d86958

Browse files
authored
Set endpoint through env var (#474)
1 parent c171074 commit 3d86958

File tree

1 file changed

+21 -3 lines changed

1 file changed

+21 -3 lines changed

ads/llm/langchain/plugins/llm_md.py

Lines changed: 21 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -8,8 +8,10 @@
88
from typing import Any, Dict, List, Optional
99

1010
import requests
11-
from oci.auth import signers
1211
from langchain.callbacks.manager import CallbackManagerForLLMRun
12+
from langchain.pydantic_v1 import root_validator
13+
from langchain.utils import get_from_dict_or_env
14+
from oci.auth import signers
1315

1416
from ads.llm.langchain.plugins.base import BaseLLM
1517
from ads.llm.langchain.plugins.contant import (
@@ -23,14 +25,26 @@
2325
class ModelDeploymentLLM(BaseLLM):
2426
"""Base class for LLM deployed on OCI Model Deployment."""
2527

26-
endpoint: str
28+
endpoint: str = ""
2729
"""The uri of the endpoint from the deployed Model Deployment model."""
2830

2931
best_of: int = 1
3032
"""Generates best_of completions server-side and returns the "best"
3133
(the one with the highest log probability per token).
3234
"""
3335

36+
@root_validator()
37+
def validate_environment( # pylint: disable=no-self-argument
38+
cls, values: Dict
39+
) -> Dict:
40+
"""Fetch endpoint from environment variable or arguments."""
41+
values["endpoint"] = get_from_dict_or_env(
42+
values,
43+
"endpoint",
44+
"OCI_LLM_ENDPOINT",
45+
)
46+
return values
47+
3448
@property
3549
def _default_params(self) -> Dict[str, Any]:
3650
"""Default parameters for the model."""
@@ -73,7 +87,7 @@ def _call(
7387
run_manager: Optional[CallbackManagerForLLMRun] = None,
7488
**kwargs: Any,
7589
) -> str:
76-
"""Call out to OCI Data Science Model Deployment TGI endpoint.
90+
"""Call out to OCI Data Science Model Deployment endpoint.
7791
7892
Parameters
7993
----------
@@ -203,8 +217,11 @@ class ModelDeploymentTGI(ModelDeploymentLLM):
203217
"""
204218

205219
watermark = True
220+
"""Watermarking with `A Watermark for Large Language Models <https://arxiv.org/abs/2301.10226>`_.
221+
Defaults to True."""
206222

207223
return_full_text = False
224+
"""Whether to prepend the prompt to the generated text. Defaults to False."""
208225

209226
@property
210227
def _llm_type(self) -> str:
@@ -241,6 +258,7 @@ class ModelDeploymentVLLM(ModelDeploymentLLM):
241258
"""VLLM deployed on OCI Model Deployment"""
242259

243260
model: str
261+
"""Name of the model."""
244262

245263
n: int = 1
246264
"""Number of output sequences to return for the given prompt."""

0 commit comments

Comments (0)