Skip to content

Commit df99a96

Browse files
OCI GenAI support for chat & embeddings (#77)
* OCI GenAI support for chat & embeddings — chat: cohere.command-r-plus-08-2024; embeddings: cohere.embed-multilingual-v3.0 * fix to force index creation for any vectorstore * fix OCI GenAI chatbot ---------
1 parent e1d1d39 commit df99a96

File tree

4 files changed

+56
-1
lines changed

4 files changed

+56
-1
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,3 +44,4 @@ temp/tools.ipynb
4444
temp/tools.py
4545
temp/json-dual.sql
4646
env.sh
47+
temp/oci_genai.py

app/src/content/oci_config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ def main():
8989
key="text_input_region",
9090
)
9191
key_file = st.text_input("Key File:", value=state.oci_config["key_file"], key="text_input_key_file")
92+
compartment_id = st.text_input("Compartment ID:", value=state.oci_config["compartment_id"], key="text_input_compartment_id")
9293

9394
if st.form_submit_button(label="Save"):
9495
print("I'm Here!")

app/src/modules/metadata.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
from langchain_openai import OpenAIEmbeddings
1818
from langchain_ollama import OllamaEmbeddings
1919
from langchain_cohere import CohereEmbeddings
20+
from langchain_community.embeddings.oci_generative_ai import OCIGenAIEmbeddings
21+
2022
from langchain.retrievers.document_compressors import CohereRerank
2123

2224
logger = logging_config.logging.getLogger("modules.metadata")
@@ -87,6 +89,21 @@ def ll_models():
8789
"""Define example Language Model Support"""
8890
# Lists are in [user, default, min, max] format
8991
ll_models_dict = {
92+
93+
"cohere.command-r-plus-08-2024": {
94+
"enabled": os.getenv("OCI_PROFILE") is not None,
95+
"api": "CohereOCI",
96+
"url": "https://inference.generativeai.us-chicago-1.oci.oraclecloud.com",
97+
"api_key": os.environ.get("OCI_PROFILE", default=""),
98+
"openai_compat": False,
99+
"context_length": 131072,
100+
"temperature": [0.3, 0.3, 0.0, 1.0],
101+
"top_p": [1.0, 1.0, 0.0, 1.0],
102+
"max_tokens": [100, 100, 1, 4096],
103+
"frequency_penalty": [0.0, 0.0, -1.0, 1.0],
104+
"presence_penalty": [0.0, 0.0, -2.0, 2.0],
105+
},
106+
90107
"command-r": {
91108
"enabled": os.getenv("COHERE_API_KEY") is not None,
92109
"api": "Cohere",
@@ -279,6 +296,15 @@ def embedding_models():
279296
"chunk_max": 8191,
280297
"dimensions": 1536,
281298
},
299+
"cohere.embed-multilingual-v3.0":{
300+
"enabled": os.getenv("OCI_PROFILE") is not None,
301+
"api": OCIGenAIEmbeddings,
302+
"url": "https://inference.generativeai.us-chicago-1.oci.oraclecloud.com",
303+
"api_key": os.environ.get("OCI_PROFILE", default=""),
304+
"openai_compat": False,
305+
"chunk_max": 512,
306+
"dimensions": 1024,
307+
},
282308
"embed-english-v3.0": {
283309
"enabled": os.getenv("COHERE_API_KEY") is not None,
284310
"api": CohereEmbeddings,

app/src/modules/utilities.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@
3232
from langchain_ollama import ChatOllama
3333
from langchain_openai import ChatOpenAI
3434

35+
from langchain_community.chat_models.oci_generative_ai import ChatOCIGenAI
36+
37+
3538
from llama_index.core import Document
3639
from llama_index.core.node_parser import SentenceSplitter
3740

@@ -102,7 +105,8 @@ def get_ll_model(model, ll_models_config=None, giskarded=False):
102105
"frequency_penalty": lm_params["frequency_penalty"][0],
103106
"presence_penalty": lm_params["presence_penalty"][0],
104107
}
105-
108+
logger.info("LLM_API:"+llm_api)
109+
106110
## Start - Add Additional Model Authentication Here
107111
client = None
108112
if giskarded:
@@ -117,6 +121,17 @@ def get_ll_model(model, ll_models_config=None, giskarded=False):
117121
client = ChatPerplexity(pplx_api_key=lm_params["api_key"], model_kwargs=common_params)
118122
elif llm_api == "ChatOllama":
119123
client = ChatOllama(model=model,base_url=lm_params["url"], model_kwargs=common_params)
124+
elif llm_api == "CohereOCI":
125+
#state.oci_config["tenancy_ocid"]
126+
client = ChatOCIGenAI(
127+
model_id=model,
128+
service_endpoint=lm_params["url"],
129+
compartment_id=os.environ.get("OCI_COMPARTMENT_ID", default=""),
130+
auth_profile=lm_params["api_key"],
131+
model_kwargs={"temperature": common_params["temperature"], "max_tokens": common_params["max_tokens"],"top_p": common_params["top_p"],
132+
"frequency_penalty": common_params["frequency_penalty"], "presence_penalty": common_params["presence_penalty"]}
133+
)
134+
120135
## End - Add Additional Model Authentication Here
121136
api_accessible, err_msg = is_url_accessible(llm_url)
122137

@@ -134,6 +149,7 @@ def get_embedding_model(model, embed_model_config=None, giskarded=False):
134149
embed_key = embed_model_config[model]["api_key"]
135150

136151
logger.debug("Matching Embedding API: %s", embed_api)
152+
137153
if giskarded:
138154
giskard_key = embed_key or "giskard"
139155
_client = OpenAI(api_key=giskard_key, base_url=f"{embed_url}/v1/")
@@ -148,6 +164,9 @@ def get_embedding_model(model, embed_model_config=None, giskarded=False):
148164
client = embed_api(model=model, base_url=embed_url)
149165
elif embed_api.__name__ == "CohereEmbeddings":
150166
client = embed_api(model=model, cohere_api_key=embed_key)
167+
elif embed_api.__name__ == "OCIGenAIEmbeddings":
168+
client = embed_api(model_id=model, service_endpoint=embed_url, compartment_id= os.environ.get("OCI_COMPARTMENT_ID", default=""), auth_profile=embed_key)
169+
151170
else:
152171
client = embed_api(model=embed_url)
153172

@@ -397,6 +416,10 @@ def json_to_doc(file: str):
397416

398417
execute_sql(db_conn, mergesql)
399418
db_conn.commit()
419+
#NOTE: In this release, the index is automatically created without user control. This code prepares a future release
420+
#in which the index can either be re-created or an existing vectorstore left as-is.
421+
#For this reason, index_exists is set to True so the index is always recreated.
422+
index_exists=True
400423

401424
if (index_exists):
402425
# Build the Index
@@ -498,6 +521,7 @@ def oci_initialize(
498521
region=None,
499522
key_file=None,
500523
security_token_file=None,
524+
compartment_id=None,
501525
):
502526
"""Initialize the configuration for OCI AuthN"""
503527
config = {
@@ -510,10 +534,12 @@ def oci_initialize(
510534
"additional_user_agent": "",
511535
"log_requests": False,
512536
"pass_phrase": None,
537+
"compartment_id":compartment_id,
513538
}
514539

515540
config_file = os.environ.get("OCI_CLI_CONFIG_FILE", default=oci.config.DEFAULT_LOCATION)
516541
config_profile = os.environ.get("OCI_CLI_PROFILE", default=oci.config.DEFAULT_PROFILE)
542+
compartment_id = os.environ.get("OCI_COMPARTMENT_ID", default="")
517543

518544
# Ingest config file when parameter are missing
519545
if not (fingerprint and tenancy and region and key_file and (user or security_token_file)):
@@ -529,6 +555,7 @@ def oci_initialize(
529555
config["tenancy"] = os.environ.get("OCI_CLI_TENANCY", config.get("tenancy"))
530556
config["region"] = os.environ.get("OCI_CLI_REGION", config.get("region"))
531557
config["key_file"] = os.environ.get("OCI_CLI_KEY_FILE", config.get("key_file"))
558+
config["compartment_id"] = os.environ.get("OCI_COMPARTMENT_ID", config.get("compartment_id"))
532559
return config
533560

534561

0 commit comments

Comments (0)