Skip to content

Commit 3f4c91a

Browse files
jkwatsonewilliams-clouderamliu-clouderabaasitshariefCopilot
authored
OpenSearch support (#223)
* add http auth for the low level client as well * wip on chunk retrieval * implement raw query to return the node contents for OpenSearch * don't refetch chat history on window focus to allow for continued streaming * add in oss username and password * drop databases lastFile:llm-service/app/services/amp_metadata/__init__.py * create opensearch config for ui and backend * wip lastFile:ui/src/pages/Settings/VectorDBFields.tsx * drop databases lastFile:ui/src/pages/Settings/VectorDBFields.tsx * WIP vector db provider lastFile:ui/src/pages/Settings/VectorDBFields.tsx * cleanup opensearch config setting * add namespace for opensearch * Handle 404s properly with opensearch deletions & other possibly failed deletion * fix a small mypy issue * add beta tag * Update ui/src/pages/RagChatTab/SessionsSidebar/CreateSession/CreateSessionForm.tsx Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * handle the case where we can't find the chunk * use typeadapters * log query --------- Co-authored-by: Elijah Williams <ewilliams@cloudera.com> Co-authored-by: Michael Liu <mliu@cloudera.com> Co-authored-by: Baasit Sharief <baasitsharief@gmail.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 04d4a04 commit 3f4c91a

File tree

17 files changed

+345
-31
lines changed

17 files changed

+345
-31
lines changed

.env.example

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,12 @@ OPENAI_API_VERSION=
1313
# QDRANT or OPENSEARCH
1414
VECTOR_DB_PROVIDER=QDRANT
1515

16+
# OpenSearch
17+
OPENSEARCH_ENDPOINT=
18+
OPENSEARCH_USERNAME=
19+
OPENSEARCH_PASSWORD=
20+
OPENSEARCH_NAMESPACE=
21+
1622
# AWS
1723
AWS_ACCESS_KEY_ID=
1824
AWS_SECRET_ACCESS_KEY=

backend/src/main/java/com/cloudera/cai/rag/external/RagBackendClient.java

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
import com.cloudera.cai.util.Tracker;
4646
import com.cloudera.cai.util.exceptions.ClientError;
4747
import com.cloudera.cai.util.exceptions.HttpError;
48+
import com.cloudera.cai.util.exceptions.NotFound;
4849
import com.cloudera.cai.util.exceptions.ServerError;
4950
import com.fasterxml.jackson.annotation.JsonProperty;
5051
import com.fasterxml.jackson.core.JsonProcessingException;
@@ -53,9 +54,11 @@
5354
import java.io.IOException;
5455
import java.util.Arrays;
5556
import java.util.List;
57+
import lombok.extern.slf4j.Slf4j;
5658
import org.springframework.beans.factory.annotation.Autowired;
5759
import org.springframework.stereotype.Component;
5860

61+
@Slf4j
5962
@Component
6063
public class RagBackendClient {
6164
private static final String AUTH_TOKEN = System.getenv("CDSW_APIV2_KEY");
@@ -125,22 +128,34 @@ public String createSummary(Types.RagDocument ragDocument, String bucketName) {
125128
}
126129

127130
public void deleteDataSource(Long dataSourceId) {
128-
client.delete(
129-
getLlmServiceUrl() + "/data_sources/" + dataSourceId,
130-
"Authorization",
131-
"Bearer " + AUTH_TOKEN);
131+
try {
132+
client.delete(
133+
getLlmServiceUrl() + "/data_sources/" + dataSourceId,
134+
"Authorization",
135+
"Bearer " + AUTH_TOKEN);
136+
} catch (NotFound e) {
137+
log.info("Data source not found. Deletion not necessary.");
138+
}
132139
}
133140

134141
public void deleteDocument(long dataSourceId, String documentId) {
135-
client.delete(
136-
getLlmServiceUrl() + "/data_sources/" + dataSourceId + "/documents/" + documentId,
137-
"Authorization",
138-
"Bearer " + AUTH_TOKEN);
142+
try {
143+
client.delete(
144+
getLlmServiceUrl() + "/data_sources/" + dataSourceId + "/documents/" + documentId,
145+
"Authorization",
146+
"Bearer " + AUTH_TOKEN);
147+
} catch (NotFound e) {
148+
log.info("Document not found. Deletion not necessary.");
149+
}
139150
}
140151

141152
public void deleteSession(Long sessionId) {
142-
client.delete(
143-
getLlmServiceUrl() + "/sessions/" + sessionId, "Authorization", "Bearer " + AUTH_TOKEN);
153+
try {
154+
client.delete(
155+
getLlmServiceUrl() + "/sessions/" + sessionId, "Authorization", "Bearer " + AUTH_TOKEN);
156+
} catch (NotFound e) {
157+
log.info("Session not found. Deletion not necessary.");
158+
}
144159
}
145160

146161
record IndexRequest(

backend/src/main/java/com/cloudera/cai/util/SimpleHttpClient.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,10 @@ public void delete(String path, String... headers) {
109109
HttpResponse<String> response =
110110
httpClient.send(request, HttpResponse.BodyHandlers.ofString());
111111
int statusCode = response.statusCode();
112+
if (statusCode == 404) {
113+
throw new NotFound("Failed to delete. Not Found");
114+
}
115+
112116
if (statusCode >= 400) {
113117
throw new RuntimeException(
114118
"Failed to delete " + path + " code: " + statusCode + ", body : " + response.body());

backend/src/test/java/com/cloudera/cai/rag/external/RagBackendClientTest.java

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,20 @@ void deleteDataSource() {
213213
HttpMethod.DELETE, "http://localhost:8081/data_sources/1234", null));
214214
}
215215

216+
@Test
217+
void deleteDataSource_notFound() {
218+
Tracker<TrackedHttpRequest<?>> tracker = new Tracker<>();
219+
RagBackendClient client =
220+
new RagBackendClient(SimpleHttpClient.createNull(tracker, new NotFound("Not found!")));
221+
client.deleteDataSource(1234L);
222+
List<TrackedHttpRequest<?>> values = tracker.getValues();
223+
assertThat(values)
224+
.hasSize(1)
225+
.contains(
226+
new TrackedHttpRequest<>(
227+
HttpMethod.DELETE, "http://localhost:8081/data_sources/1234", null));
228+
}
229+
216230
@Test
217231
void deleteDocument() {
218232
Tracker<TrackedHttpRequest<?>> tracker = new Tracker<>();
@@ -228,6 +242,22 @@ void deleteDocument() {
228242
null));
229243
}
230244

245+
@Test
246+
void deleteDocument_notFound() {
247+
Tracker<TrackedHttpRequest<?>> tracker = new Tracker<>();
248+
RagBackendClient client =
249+
new RagBackendClient(SimpleHttpClient.createNull(tracker, new NotFound("Not found!")));
250+
client.deleteDocument(1234L, "documentId");
251+
List<TrackedHttpRequest<?>> values = tracker.getValues();
252+
assertThat(values)
253+
.hasSize(1)
254+
.contains(
255+
new TrackedHttpRequest<>(
256+
HttpMethod.DELETE,
257+
"http://localhost:8081/data_sources/1234/documents/documentId",
258+
null));
259+
}
260+
231261
@Test
232262
void deleteSession() {
233263
Tracker<TrackedHttpRequest<?>> tracker = new Tracker<>();
@@ -241,6 +271,20 @@ void deleteSession() {
241271
HttpMethod.DELETE, "http://localhost:8081/sessions/1234", null));
242272
}
243273

274+
@Test
275+
void deleteSession_notFound() {
276+
Tracker<TrackedHttpRequest<?>> tracker = new Tracker<>();
277+
RagBackendClient client =
278+
new RagBackendClient(SimpleHttpClient.createNull(tracker, new NotFound("Not found!")));
279+
client.deleteSession(1234L);
280+
List<TrackedHttpRequest<?>> values = tracker.getValues();
281+
assertThat(values)
282+
.hasSize(1)
283+
.contains(
284+
new TrackedHttpRequest<>(
285+
HttpMethod.DELETE, "http://localhost:8081/sessions/1234", null));
286+
}
287+
244288
@Test
245289
void null_handlesThrowable() {
246290
RagBackendClient client =

llm-service/app/ai/vector_stores/opensearch.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,10 @@
4040
from abc import ABC
4141
from typing import Optional, List
4242

43+
import fastapi.exceptions
44+
import opensearchpy
4345
from llama_index.core.base.embeddings.base import BaseEmbedding
46+
from llama_index.core.schema import BaseNode, TextNode
4447
from llama_index.core.vector_stores.types import BasePydanticVectorStore
4548
from llama_index.vector_stores.opensearch import (
4649
OpensearchVectorStore,
@@ -55,18 +58,21 @@
5558

5659
logger = logging.getLogger(__name__)
5760

61+
5862
def _new_opensearch_client(dim: int, index: str) -> OpensearchVectorClient:
5963
return OpensearchVectorClient(
60-
# username=os.environ.get("OPENSEARCH_USERNAME", "admin"),
61-
# password=os.environ.get("OPENSEARCH_INITIAL_ADMIN_PASSWORD"),
6264
endpoint=settings.opensearch_endpoint,
6365
index=index,
6466
dim=dim,
67+
http_auth=(settings.opensearch_username, settings.opensearch_password),
6568
)
6669

6770

6871
def _get_low_level_client() -> OpensearchClient:
69-
os_client = OpensearchClient(settings.opensearch_endpoint)
72+
os_client = OpensearchClient(
73+
settings.opensearch_endpoint,
74+
http_auth=(settings.opensearch_username, settings.opensearch_password),
75+
)
7076
return os_client
7177

7278

@@ -77,14 +83,14 @@ class OpenSearch(VectorStore, ABC):
7783
def for_chunks(data_source_id: int) -> "OpenSearch":
7884
return OpenSearch(
7985
data_source_id=data_source_id,
80-
table_name=f"index_{data_source_id}",
86+
table_name=f"{settings.opensearch_namespace}__index_{data_source_id}",
8187
)
8288

8389
@staticmethod
8490
def for_summaries(data_source_id: int) -> "OpenSearch":
8591
return OpenSearch(
8692
data_source_id=data_source_id,
87-
table_name=f"summary_index_{data_source_id}",
93+
table_name=f"{settings.opensearch_namespace}__summary_index_{data_source_id}",
8894
)
8995

9096
def __init__(
@@ -110,7 +116,10 @@ def size(self) -> Optional[int]:
110116

111117
def delete(self) -> None:
112118
os_client = self._low_level_client
113-
os_client.indices.delete(index=self.table_name)
119+
try:
120+
os_client.indices.delete(index=self.table_name)
121+
except opensearchpy.exceptions.NotFoundError:
122+
raise fastapi.exceptions.HTTPException(404, "Index not found")
114123

115124
def delete_document(self, document_id: str) -> None:
116125
self._get_client().delete_by_doc_id(document_id)
@@ -120,6 +129,19 @@ def llama_vector_store(self) -> BasePydanticVectorStore:
120129
self._get_client(),
121130
)
122131

132+
def get_chunk_contents(self, chunk_id: str) -> BaseNode:
133+
query = {"query": {"terms": {"_id": [chunk_id]}}}
134+
raw_results = self._low_level_client.search(index=self.table_name, body=query)
135+
if raw_results["hits"] and raw_results["hits"]["hits"]:
136+
return TextNode(
137+
id_=chunk_id, text=raw_results["hits"]["hits"][0]["_source"]["content"]
138+
)
139+
else:
140+
logger.error(f"Chunk not found for query: {query}")
141+
raise fastapi.exceptions.HTTPException(
142+
404, "Chunk not found with id: " + chunk_id
143+
)
144+
123145
def _get_client(self) -> OpensearchVectorClient:
124146
return _new_opensearch_client(
125147
dim=self._find_dim(self.data_source_id),

llm-service/app/ai/vector_stores/vector_store.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141

4242
import umap
4343
from llama_index.core.base.embeddings.base import BaseEmbedding
44+
from llama_index.core.schema import BaseNode
4445
from llama_index.core.vector_stores.types import BasePydanticVectorStore
4546

4647
logger = logging.getLogger(__name__)
@@ -109,3 +110,6 @@ def visualize_embeddings(
109110
# Log the error
110111
logger.error(f"Error during UMAP transformation: {e}")
111112
return []
113+
114+
def get_chunk_contents(self, chunk_id: str) -> BaseNode :
115+
return self.llama_vector_store().get_nodes([chunk_id])[0]

llm-service/app/config.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151

5252
SummaryStorageProviderType = Literal["Local", "S3"]
5353
ChatStoreProviderType = Literal["Local", "S3"]
54+
VectorDbProviderType = Literal["QDRANT", "OPENSEARCH"]
5455

5556

5657
class _Settings:
@@ -108,6 +109,18 @@ def vector_db_provider(self) -> Optional[str]:
108109
def opensearch_endpoint(self) -> str:
109110
return os.environ.get("OPENSEARCH_ENDPOINT", "http://localhost:9200")
110111

112+
@property
113+
def opensearch_namespace(self) -> str:
114+
return os.environ.get("OPENSEARCH_NAMESPACE", "rag_document_index")
115+
116+
@property
117+
def opensearch_username(self) -> str:
118+
return os.environ.get("OPENSEARCH_USERNAME", "")
119+
120+
@property
121+
def opensearch_password(self) -> str:
122+
return os.environ.get("OPENSEARCH_PASSWORD", "")
123+
111124
@property
112125
def document_bucket_prefix(self) -> str:
113126
return os.environ.get("S3_RAG_BUCKET_PREFIX", "")

llm-service/app/routers/index/data_source/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def delete(self, data_source_id: int) -> None:
112112
)
113113
@exceptions.propagates
114114
def chunk_contents(self, chunk_id: str) -> ChunkContentsResponse:
115-
node = self.chunks_vector_store.llama_vector_store().get_nodes([chunk_id])[0]
115+
node = self.chunks_vector_store.get_chunk_contents(chunk_id)
116116
return ChunkContentsResponse(
117117
text=node.get_content(),
118118
metadata=node.metadata,

llm-service/app/services/amp_metadata/__init__.py

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,14 @@
4040
import socket
4141
from typing import Optional, cast, Protocol
4242

43-
from pydantic import BaseModel
43+
from pydantic import BaseModel, TypeAdapter
4444

45-
from app.config import settings, SummaryStorageProviderType, ChatStoreProviderType
45+
from app.config import (
46+
settings,
47+
SummaryStorageProviderType,
48+
ChatStoreProviderType,
49+
VectorDbProviderType,
50+
)
4651

4752

4853
class AwsConfig(BaseModel):
@@ -83,6 +88,15 @@ class OpenAiConfig(BaseModel):
8388
openai_api_key: Optional[str] = None
8489
openai_api_base: Optional[str] = None
8590

91+
92+
class OpenSearchConfig(BaseModel):
93+
94+
opensearch_username: Optional[str] = None
95+
opensearch_password: Optional[str] = None
96+
opensearch_endpoint: Optional[str] = None
97+
opensearch_namespace: Optional[str] = None
98+
99+
86100
class ProjectConfig(BaseModel):
87101
"""
88102
Model to represent the project configuration.
@@ -91,10 +105,12 @@ class ProjectConfig(BaseModel):
91105
use_enhanced_pdf_processing: Optional[bool] = False
92106
summary_storage_provider: SummaryStorageProviderType
93107
chat_store_provider: ChatStoreProviderType
108+
vector_db_provider: VectorDbProviderType
94109
aws_config: AwsConfig
95110
azure_config: AzureConfig
96111
caii_config: CaiiConfig
97112
openai_config: OpenAiConfig
113+
opensearch_config: OpenSearchConfig
98114
cdp_token: Optional[str] = None
99115

100116

@@ -200,6 +216,7 @@ def config_to_env(config: ProjectConfig) -> dict[str, str]:
200216
"USE_ENHANCED_PDF_PROCESSING": str(config.use_enhanced_pdf_processing).lower(),
201217
"SUMMARY_STORAGE_PROVIDER": config.summary_storage_provider or "Local",
202218
"CHAT_STORE_PROVIDER": config.chat_store_provider or "Local",
219+
"VECTOR_DB_PROVIDER": config.vector_db_provider or "QDRANT",
203220
"AWS_DEFAULT_REGION": config.aws_config.region or "",
204221
"S3_RAG_DOCUMENT_BUCKET": config.aws_config.document_bucket_name or "",
205222
"S3_RAG_BUCKET_PREFIX": config.aws_config.bucket_prefix or "",
@@ -209,6 +226,10 @@ def config_to_env(config: ProjectConfig) -> dict[str, str]:
209226
"AZURE_OPENAI_ENDPOINT": config.azure_config.openai_endpoint or "",
210227
"OPENAI_API_VERSION": config.azure_config.openai_api_version or "",
211228
"CAII_DOMAIN": config.caii_config.caii_domain or "",
229+
"OPENSEARCH_USERNAME": config.opensearch_config.opensearch_username or "",
230+
"OPENSEARCH_PASSWORD": config.opensearch_config.opensearch_password or "",
231+
"OPENSEARCH_ENDPOINT": config.opensearch_config.opensearch_endpoint or "",
232+
"OPENSEARCH_NAMESPACE": config.opensearch_config.opensearch_namespace or "",
212233
"OPENAI_API_KEY": config.openai_config.openai_api_key or "",
213234
"OPENAI_API_BASE": config.openai_config.openai_api_base or "",
214235
}
@@ -235,22 +256,33 @@ def build_configuration(
235256
caii_config = CaiiConfig(
236257
caii_domain=env.get("CAII_DOMAIN"),
237258
)
259+
opensearch_config = OpenSearchConfig(
260+
opensearch_username=env.get(
261+
"OPENSEARCH_USERNAME",
262+
),
263+
opensearch_password=env.get("OPENSEARCH_PASSWORD"),
264+
opensearch_endpoint=env.get("OPENSEARCH_ENDPOINT"),
265+
opensearch_namespace=env.get("OPENSEARCH_NAMESPACE"),
266+
)
238267
return ProjectConfigPlus(
239-
use_enhanced_pdf_processing=cast(
240-
bool,
268+
use_enhanced_pdf_processing=TypeAdapter(bool).validate_python(
241269
env.get("USE_ENHANCED_PDF_PROCESSING", False),
242270
),
243-
summary_storage_provider=cast(
244-
SummaryStorageProviderType,
271+
summary_storage_provider=TypeAdapter(
272+
SummaryStorageProviderType
273+
).validate_python(
245274
env.get("SUMMARY_STORAGE_PROVIDER", "Local"),
246275
),
247-
chat_store_provider=cast(
248-
ChatStoreProviderType,
276+
chat_store_provider=TypeAdapter(ChatStoreProviderType).validate_python(
249277
env.get("CHAT_STORE_PROVIDER", "Local"),
250278
),
279+
vector_db_provider=TypeAdapter(VectorDbProviderType).validate_python(
280+
env.get("VECTOR_DB_PROVIDER", "QDRANT")
281+
),
251282
aws_config=aws_config,
252283
azure_config=azure_config,
253284
caii_config=caii_config,
285+
opensearch_config=opensearch_config,
254286
is_valid_config=validate(env),
255287
release_version=os.environ.get("RELEASE_TAG", "unknown"),
256288
application_config=application_config,

0 commit comments

Comments
 (0)