|
8 | 8 | from typing import Any, Dict, List, Optional
|
9 | 9 |
|
10 | 10 | import requests
|
11 |
| -from oci.auth import signers |
12 | 11 | from langchain.callbacks.manager import CallbackManagerForLLMRun
|
| 12 | +from langchain.pydantic_v1 import root_validator |
| 13 | +from langchain.utils import get_from_dict_or_env |
| 14 | +from oci.auth import signers |
13 | 15 |
|
14 | 16 | from ads.llm.langchain.plugins.base import BaseLLM
|
15 | 17 | from ads.llm.langchain.plugins.contant import (
|
|
23 | 25 | class ModelDeploymentLLM(BaseLLM):
|
24 | 26 | """Base class for LLM deployed on OCI Model Deployment."""
|
25 | 27 |
|
26 |
| - endpoint: str |
| 28 | + endpoint: str = "" |
27 | 29 | """The uri of the endpoint from the deployed Model Deployment model."""
|
28 | 30 |
|
29 | 31 | best_of: int = 1
|
30 | 32 | """Generates best_of completions server-side and returns the "best"
|
31 | 33 | (the one with the highest log probability per token).
|
32 | 34 | """
|
33 | 35 |
|
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:  # pylint: disable=no-self-argument
    """Root validator that populates ``values["endpoint"]``.

    The endpoint is obtained from ``get_from_dict_or_env`` using the
    ``endpoint`` key of ``values`` and the ``OCI_LLM_ENDPOINT`` environment
    variable name.

    Parameters
    ----------
    values: Dict
        The field values handed to the validator.

    Returns
    -------
    Dict
        The same mapping, with ``endpoint`` overwritten by the resolved value.
    """
    # NOTE(review): presumably the helper prefers a value already present in
    # ``values`` and only then consults the environment variable; confirm
    # against the langchain.utils implementation.
    resolved_endpoint = get_from_dict_or_env(values, "endpoint", "OCI_LLM_ENDPOINT")
    values["endpoint"] = resolved_endpoint
    return values
| 47 | + |
34 | 48 | @property
|
35 | 49 | def _default_params(self) -> Dict[str, Any]:
|
36 | 50 | """Default parameters for the model."""
|
@@ -73,7 +87,7 @@ def _call(
|
73 | 87 | run_manager: Optional[CallbackManagerForLLMRun] = None,
|
74 | 88 | **kwargs: Any,
|
75 | 89 | ) -> str:
|
76 |
| - """Call out to OCI Data Science Model Deployment TGI endpoint. |
| 90 | + """Call out to OCI Data Science Model Deployment endpoint. |
77 | 91 |
|
78 | 92 | Parameters
|
79 | 93 | ----------
|
@@ -203,8 +217,11 @@ class ModelDeploymentTGI(ModelDeploymentLLM):
|
203 | 217 | """
|
204 | 218 |
|
205 | 219 | watermark = True
|
| 220 | + """Watermarking with `A Watermark for Large Language Models <https://arxiv.org/abs/2301.10226>`_. |
| 221 | + Defaults to True.""" |
206 | 222 |
|
207 | 223 | return_full_text = False
|
| 224 | + """Whether to prepend the prompt to the generated text. Defaults to False.""" |
208 | 225 |
|
209 | 226 | @property
|
210 | 227 | def _llm_type(self) -> str:
|
@@ -241,6 +258,7 @@ class ModelDeploymentVLLM(ModelDeploymentLLM):
|
241 | 258 | """VLLM deployed on OCI Model Deployment"""
|
242 | 259 |
|
243 | 260 | model: str
|
| 261 | + """Name of the model.""" |
244 | 262 |
|
245 | 263 | n: int = 1
|
246 | 264 | """Number of output sequences to return for the given prompt."""
|
|
0 commit comments