Skip to content

WIP on CAII Support #13

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/publish_release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ jobs:
id: create_release
with:
draft: false
prerelease: ${{ github.event.inputs.BRANCH == 'mob/main' }}
prerelease: ${{ github.event.inputs.BRANCH != 'release/1' }}
name: ${{ github.event.inputs.VERSION }}
tag_name: ${{ github.event.inputs.VERSION }}
files: |
Expand Down
12 changes: 12 additions & 0 deletions .project-metadata.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,18 @@ environment_variables:
default: ""
description: "AWS Secret Access Key"
required: true
CAII_DOMAIN:
default: ""
description: "The domain of the CAII service. Setting this will enable CAII as the sole source for both inference and embedding models."
required: false
CAII_INFERENCE_ENDPOINT_NAME:
default: ""
description: "The name of the inference endpoint for the CAII service. Required if CAII_DOMAIN is set."
required: false
CAII_EMBEDDING_ENDPOINT_NAME:
default: ""
description: "The name of the embedding endpoint for the CAII service. Required if CAII_DOMAIN is set."
required: false
DB_URL:
default: "jdbc:h2:file:~/databases/rag"
description: "Internal DB URL. Do not change."
Expand Down
7 changes: 6 additions & 1 deletion llm-service/app/routers/index/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
#
from fastapi import APIRouter
from .... import exceptions
from ....services.models import get_available_embedding_models, get_available_llm_models
from ....services.models import get_available_embedding_models, get_available_llm_models, get_model_source, ModelSource

router = APIRouter(prefix="/models")

Expand All @@ -50,3 +50,8 @@ def get_llm_models() -> list:
@exceptions.propagates
def get_llm_embedding_models() -> list:
return get_available_embedding_models()

# GET /models/model_source
# Reports which backend (Bedrock or CAII) currently supplies the models.
# NOTE: kept as comments rather than a docstring so the generated OpenAPI
# description of the route is left untouched.
@router.get("/model_source", summary="Model source enabled - Bedrock or CAII")
@exceptions.propagates
def get_model() -> ModelSource:
    active_source = get_model_source()
    return active_source
18 changes: 16 additions & 2 deletions llm-service/app/services/caii.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import json
import os

from fastapi import HTTPException
from llama_index.core.base.embeddings.base import BaseEmbedding
from llama_index.core.llms import LLM

Expand All @@ -60,6 +61,9 @@ def describe_endpoint(domain: str, endpoint_name: str):
}

desc = requests.post(describe_url, headers=headers, json=desc_json)
if desc.status_code == 404:
raise HTTPException(status_code=404, detail = f"Endpoint '{endpoint_name}' not found")
print(desc.content)
content = json.loads(desc.content)
return content

Expand Down Expand Up @@ -92,14 +96,24 @@ def get_embedding_model() -> BaseEmbedding:
endpoint = describe_endpoint(domain=domain, endpoint_name=endpoint_name)
return CaiiEmbeddingModel(endpoint=endpoint)

### metadata methods below here

def get_caii_llm_models():
    """List the CAII inference endpoint as the only available LLM model.

    Reads CAII_DOMAIN and CAII_INFERENCE_ENDPOINT_NAME from the environment,
    describes that endpoint and converts the description into the model-list
    shape returned to the UI (including ``available`` / ``replica_count``).

    Raises:
        KeyError: if either environment variable is not set.
        HTTPException: 404, propagated from describe_endpoint, when no
            endpoint with the configured name exists.
    """
    domain = os.environ['CAII_DOMAIN']
    endpoint_name = os.environ['CAII_INFERENCE_ENDPOINT_NAME']
    models = describe_endpoint(domain=domain, endpoint_name=endpoint_name)
    # Single exit through the shared builder: the earlier hand-built
    # name-only list returned first and left this call unreachable, so the
    # availability fields never made it into the response.
    return build_model_response(models)

def get_caii_embedding_models():
    """List the CAII embedding endpoint as the only available embedding model.

    Reads CAII_DOMAIN and CAII_EMBEDDING_ENDPOINT_NAME from the environment,
    describes that endpoint and converts the description into the model-list
    shape returned to the UI (including ``available`` / ``replica_count``).

    Raises:
        KeyError: if either environment variable is not set.
        HTTPException: 404, propagated from describe_endpoint, when no
            endpoint with the configured name exists.
    """
    # Failure modes noted by the original author:
    #   - NameResolutionError means the CAII_DOMAIN host could not be contacted
    #   - HTTPException (404) means the endpoint could not be found by name
    domain = os.environ['CAII_DOMAIN']
    endpoint_name = os.environ['CAII_EMBEDDING_ENDPOINT_NAME']
    models = describe_endpoint(domain=domain, endpoint_name=endpoint_name)
    # Single exit through the shared builder: the earlier hand-built
    # name-only list returned first and left this call unreachable, so the
    # availability fields never made it into the response.
    return build_model_response(models)

def build_model_response(models):
    """Turn one endpoint description into the single-entry model list.

    ``models`` is the mapping returned by describe_endpoint; only its
    ``name`` and ``replica_count`` keys are read. The endpoint name doubles
    as the model id, and the model is reported as available when at least
    one replica is running.
    """
    endpoint_name = models["name"]
    replicas = models["replica_count"]
    entry = {
        "model_id": endpoint_name,
        "name": endpoint_name,
        "available": replicas > 0,
        "replica_count": replicas,
    }
    return [entry]
9 changes: 9 additions & 0 deletions llm-service/app/services/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
# DATA.
#
import os
from enum import Enum

from llama_index.core.base.embeddings.base import BaseEmbedding
from llama_index.core.llms import LLM
Expand Down Expand Up @@ -97,3 +98,11 @@ def _get_bedrock_embedding_models():
"name": "cohere.embed-english-v3",
}]

class ModelSource(str, Enum):
    """Backend that supplies both the inference and the embedding models."""

    BEDROCK = "Bedrock"
    CAII = "CAII"


def get_model_source() -> ModelSource:
    """Return the model backend selected by the environment.

    CAII is selected only when CAII_DOMAIN holds a non-blank value. The
    project metadata declares CAII_DOMAIN as optional with an empty-string
    default, so the variable can be present but empty; a bare membership
    test would then report CAII for a deployment that never configured it.
    """
    caii_domain = os.environ.get("CAII_DOMAIN", "").strip()
    if caii_domain:
        return ModelSource.CAII
    return ModelSource.BEDROCK
17 changes: 16 additions & 1 deletion ui/src/api/modelsApi.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,14 @@
* BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF
* DATA.
******************************************************************************/
import { useQuery } from "@tanstack/react-query";
import { queryOptions, useQuery } from "@tanstack/react-query";
import { getRequest, llmServicePath, QueryKeys } from "src/api/utils.ts";

/**
 * One entry of the model lists served by the llm-service
 * (`/index/models/...`).
 */
export interface Model {
  /** Identifier sent back to the service when the model is selected. */
  model_id: string;
  /** Display name of the model. */
  name: string;
  /**
   * Optional availability flag. The CAII backend sets it to whether the
   * endpoint has at least one replica.
   */
  available?: boolean;
  /** Optional replica count, set by the CAII backend alongside `available`. */
  replica_count?: number;
}

export const useGetLlmModels = () => {
Expand Down Expand Up @@ -68,3 +70,16 @@ export const useGetEmbeddingModels = () => {
/** Fetches the list of embedding models offered by the llm-service. */
const getEmbeddingModels = async (): Promise<Model[]> => {
  const url = `${llmServicePath}/index/models/embeddings`;
  return await getRequest(url);
};

/**
 * Backend currently supplying the models. The literals match the values of
 * the llm-service `ModelSource` enum.
 */
type ModelSource = "Bedrock" | "CAII";

/**
 * Shared react-query options for the "which model source is enabled" query,
 * so loaders and hooks use the same key and fetcher.
 */
export const getModelSourceQueryOptions = queryOptions({
  queryKey: [QueryKeys.getModelSource],
  // Keep the closure: getModelSource is a `const` declared further down the
  // module, so it must only be dereferenced when the query actually runs.
  queryFn: () => getModelSource(),
});

/** Asks the llm-service which model source (CAII or Bedrock) is enabled. */
const getModelSource = async (): Promise<ModelSource> => {
  const url = `${llmServicePath}/index/models/model_source`;
  return await getRequest(url);
};
1 change: 1 addition & 0 deletions ui/src/api/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ export enum QueryKeys {
"getAmpUpdateJobStatus" = "getAmpUpdateJobStatus",
"getLlmModels" = "getLlmModels",
"getEmbeddingModels" = "getEmbeddingModels",
"getModelSource" = "getModelSource",
}

export const commonHeaders = {
Expand Down
24 changes: 17 additions & 7 deletions ui/src/layout/Sidebar.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ const Sidebar: React.FC = () => {
feedbackModal.setIsModalOpen(true);
};

const items: MenuItem[] = [
const baseItems: MenuItem[] = [
{
label: (
<Tag
Expand Down Expand Up @@ -133,14 +133,24 @@ const Sidebar: React.FC = () => {
navToData,
<DatabaseFilled />,
),
getItem(
<div data-testid="data-management-nav">Leave Feedback</div>,
"leave-feedback",
popupFeedback,
<ThumbUpIcon />,
),
];

// const caiiModels = getItem(
// <div data-testid="data-management-nav">CAII Models</div>,
// "caii-models",
// navToData,
// <DatabaseFilled />,
// )

const feedbackItem = getItem(
<div data-testid="data-management-nav">Leave Feedback</div>,
"leave-feedback",
popupFeedback,
<ThumbUpIcon />,
);

const items = [...baseItems, feedbackItem];

function chooseRoute() {
if (matchRoute({ to: "/data", fuzzy: true })) {
return ["data"];
Expand Down
2 changes: 1 addition & 1 deletion ui/src/routeTree.gen.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ const LayoutDataImport = createFileRoute('/_layout/data')()
const LayoutRoute = LayoutImport.update({
id: '/_layout',
getParentRoute: () => rootRoute,
} as any)
} as any).lazy(() => import('./routes/_layout.lazy').then((d) => d.Route))

const IndexRoute = IndexImport.update({
id: '/',
Expand Down
54 changes: 54 additions & 0 deletions ui/src/routes/_layout.lazy.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*******************************************************************************
* CLOUDERA APPLIED MACHINE LEARNING PROTOTYPE (AMP)
* (C) Cloudera, Inc. 2024
* All rights reserved.
*
* Applicable Open Source License: Apache 2.0
*
* NOTE: Cloudera open source products are modular software products
* made up of hundreds of individual components, each of which was
* individually copyrighted. Each Cloudera open source product is a
* collective work under U.S. Copyright Law. Your license to use the
* collective work is as provided in your written agreement with
* Cloudera. Used apart from the collective work, this file is
* licensed for your use pursuant to the open source license
* identified above.
*
* This code is provided to you pursuant a written agreement with
* (i) Cloudera, Inc. or (ii) a third-party authorized to distribute
* this code. If you do not have a written agreement with Cloudera nor
* with an authorized and properly licensed third party, you do not
* have any rights to access nor to use this code.
*
* Absent a written agreement with Cloudera, Inc. ("Cloudera") to the
* contrary, A) CLOUDERA PROVIDES THIS CODE TO YOU WITHOUT WARRANTIES OF ANY
* KIND; (B) CLOUDERA DISCLAIMS ANY AND ALL EXPRESS AND IMPLIED
* WARRANTIES WITH RESPECT TO THIS CODE, INCLUDING BUT NOT LIMITED TO
* IMPLIED WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE; (C) CLOUDERA IS NOT LIABLE TO YOU,
* AND WILL NOT DEFEND, INDEMNIFY, NOR HOLD YOU HARMLESS FOR ANY CLAIMS
* ARISING FROM OR RELATED TO THE CODE; AND (D)WITH RESPECT TO YOUR EXERCISE
* OF ANY RIGHTS GRANTED TO YOU FOR THE CODE, CLOUDERA IS NOT LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, PUNITIVE OR
* CONSEQUENTIAL DAMAGES INCLUDING, BUT NOT LIMITED TO, DAMAGES
* RELATED TO LOST REVENUE, LOST PROFITS, LOSS OF INCOME, LOSS OF
* BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF
* DATA.
******************************************************************************/

import { createLazyFileRoute, Outlet } from "@tanstack/react-router";
import { Layout } from "antd";
import Sidebar from "src/layout/Sidebar.tsx";

const { Content } = Layout;

/**
 * Application shell: the sidebar next to a scrollable content area in which
 * the matched child route is rendered.
 */
const LayoutShell = () => (
  <Layout style={{ minHeight: "100vh" }}>
    <Sidebar />
    <Content style={{ margin: "0", overflowY: "auto" }}>
      <Outlet />
    </Content>
  </Layout>
);

/** Lazily loaded half of the `/_layout` route; it only supplies the component. */
export const Route = createLazyFileRoute("/_layout")({
  component: LayoutShell,
});
17 changes: 4 additions & 13 deletions ui/src/routes/_layout.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -36,19 +36,10 @@
* DATA.
******************************************************************************/

import { createFileRoute, Outlet } from "@tanstack/react-router";
import { Layout } from "antd";
import Sidebar from "src/layout/Sidebar.tsx";

const { Content } = Layout;
import { createFileRoute } from "@tanstack/react-router";
import { getModelSourceQueryOptions } from "src/api/modelsApi.ts";

/**
 * Critical (non-lazy) half of the `/_layout` route.
 *
 * It only declares the loader, which makes sure the model-source query is in
 * the query cache before the layout renders. The layout component itself is
 * defined in `_layout.lazy.tsx` and attached through `.lazy(...)` in the
 * generated route tree, so it is not repeated here.
 */
export const Route = createFileRoute("/_layout")({
  loader: async ({ context: { queryClient } }) =>
    queryClient.ensureQueryData(getModelSourceQueryOptions),
});