Commit 1d4d9be

add codemodels.py
1 parent f15bf11 commit 1d4d9be

File tree

1 file changed: +185 -0 lines changed


WebUI/configs/codemodels.py

Lines changed: 185 additions & 0 deletions
@@ -0,0 +1,185 @@
import json
import asyncio
import base64
import threading
from WebUI.configs.basicconfig import *
from fastapi.responses import StreamingResponse
from WebUI.configs.webuiconfig import InnerJsonConfigWebUIParse
from WebUI.Server.db.repository import add_chat_history_to_db, update_chat_history
from WebUI.Server.chat.StreamHandler import StreamSpeakHandler
from WebUI.Server.utils import FastAPI

def load_causallm_model(app: FastAPI, model_name, model_path, device):
    # Load a generic causal LM with transformers and attach it to the FastAPI app.
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from langchain.llms.huggingface_pipeline import HuggingFacePipeline
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, torch_dtype="auto", device_map=device)
    app._model = model
    app._tokenizer = tokenizer
    app._model_name = model_name

def load_llama_model(app: FastAPI, model_name, model_path, device):
    # Load a llama-family model behind a text-generation pipeline with a token streamer.
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, pipeline
    from langchain.llms.huggingface_pipeline import HuggingFacePipeline
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.float16, device_map=device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        temperature=0.2,
        torch_dtype=torch.float16,
        streamer=streamer
    )
    # pipe.model.config.pad_token_id = pipe.model.config.eos_token_id
    # llm_model = HuggingFacePipeline(pipeline=pipe)
    app._model = pipe
    app._streamer = streamer
    app._tokenizer = tokenizer
    app._model_name = model_name

def init_code_models(app: FastAPI, args):
    # Resolve the model's type/size/config from the WebUI configuration and load it
    # with the matching loader ("causallm" or "llama").
    model_name = args.model_names[0]
    model_path = args.model_path
    if len(model_name) == 0 or len(model_path) == 0:
        return
    configinst = InnerJsonConfigWebUIParse()
    webui_config = configinst.dump()
    model_info: Dict[str, Any] = {"mtype": ModelType.Unknown, "msize": ModelSize.Unknown, "msubtype": ModelSubType.Unknown, "mname": str, "config": dict}
    model_info["mtype"], model_info["msize"], model_info["msubtype"] = GetModelInfoByName(webui_config, model_name)
    model_info["mname"] = model_name
    model_config = GetModelConfig(webui_config, model_info)
    load_type = model_config.get("load_type", "")
    if load_type == "causallm":
        load_causallm_model(app=app, model_name=model_name, model_path=model_path, device=args.device)
    elif load_type == "llama":
        load_llama_model(app=app, model_name=model_name, model_path=model_path, device=args.device)

def code_model_chat(
    model: Any,
    tokenizer: Any,
    async_callback: Any,
    modelinfo: Any,
    query: str,
    imagesdata: List[str],
    audiosdata: List[str],
    videosdata: List[str],
    imagesprompt: List[str],
    history: List[dict],
    stream: bool,
    speechmodel: dict,
    temperature: float,
    max_tokens: Optional[int],
    prompt_name: str,
):
    # Dispatch a code-generation chat request and return the answer as a
    # text/event-stream of JSON chunks.
    if modelinfo is None:
        return json.dumps(
            {"text": "Unusual error!", "chat_history_id": 123},
            ensure_ascii=False)

    async def code_chat_iterator(model: Any,
                                 tokenizer: Any,
                                 async_callback: Any,
                                 query: str,
                                 imagesdata: List[str],
                                 audiosdata: List[str],
                                 videosdata: List[str],
                                 imagesprompt: List[str],
                                 history: List[dict] = [],
                                 modelinfo: Any = None,
                                 temperature: float = 0.7,
                                 max_tokens: Optional[int] = 2048,
                                 prompt_name: str = prompt_name,
                                 ) -> AsyncIterable[str]:

        from WebUI.Server.utils import detect_device
        configinst = InnerJsonConfigWebUIParse()
        webui_config = configinst.dump()
        model_name = modelinfo["mname"]
        speak_handler = None
        # Optional text-to-speech: speak the generated answer alongside the text stream.
        if len(speechmodel):
            modeltype = speechmodel.get("type", "")
            provider = speechmodel.get("provider", "")
            #spmodel = speechmodel.get("model", "")
            spspeaker = speechmodel.get("speaker", "")
            speechkey = speechmodel.get("speech_key", "")
            speechregion = speechmodel.get("speech_region", "")
            if modeltype == "local" or modeltype == "cloud":
                speak_handler = StreamSpeakHandler(run_place=modeltype, provider=provider, synthesis=spspeaker, subscription=speechkey, region=speechregion)

        answer = ""
        chat_history_id = add_chat_history_to_db(chat_type="llm_chat", query=query)
        modelconfig = GetModelConfig(webui_config, modelinfo)
        device = modelconfig.get("device", "auto")
        device = "cuda" if device == "gpu" else detect_device() if device == "auto" else device
        if model_name == "stable-code-3b":
            if max_tokens is None:
                max_tokens = 512
            # Generate the full completion in one pass and wrap it in a python code fence.
            inputs = tokenizer(query, return_tensors="pt").to(model.device)
            tokens = model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                temperature=0.2,
                do_sample=True,
            )
            answer = tokenizer.decode(tokens[0], skip_special_tokens=True)
            #answer = answer.replace(query, '').strip()
            #while(answer.startswith(("'", '"', ' ', ',', '.', '!', '?'))):
            #    answer = answer[1:].strip()
            answer = "```python\n" + answer + "\n```"
            if speak_handler: speak_handler.on_llm_new_token(answer)
            yield json.dumps(
                {"text": answer, "chat_history_id": chat_history_id},
                ensure_ascii=False)
            await asyncio.sleep(0.1)
            if speak_handler: speak_handler.on_llm_end(None)

        elif model_name == "CodeLlama-7b-Python-hf" or \
                model_name == "CodeLlama-13b-Python-hf" or \
                model_name == "CodeLlama-7b-Instruct-hf" or \
                model_name == "CodeLlama-13b-Instruct-hf":
            # CodeLlama is loaded as a text-generation pipeline, so stream each
            # returned sequence inside a python code fence.
            sequences = model(
                query,
                #do_sample=True,
                temperature=0.2,
                top_p=0.95,
                num_return_sequences=1,
                eos_token_id=tokenizer.eos_token_id,
                max_length=512,
            )
            answer = ""
            yield json.dumps(
                {"text": "```python\n", "chat_history_id": chat_history_id},
                ensure_ascii=False)
            for seq in sequences:
                # Accumulate the full answer so it can be spoken and saved to history.
                answer += seq['generated_text']
                yield json.dumps(
                    {"text": seq['generated_text'], "chat_history_id": chat_history_id},
                    ensure_ascii=False)
                await asyncio.sleep(0.1)
            yield json.dumps(
                {"text": "\n```", "chat_history_id": chat_history_id},
                ensure_ascii=False)
            if speak_handler:
                speak_handler.on_llm_new_token(answer)
                speak_handler.on_llm_end(None)

        update_chat_history(chat_history_id, response=answer)

    return StreamingResponse(code_chat_iterator(
        model=model,
        tokenizer=tokenizer,
        async_callback=async_callback,
        query=query,
        imagesdata=imagesdata,
        audiosdata=audiosdata,
        videosdata=videosdata,
        imagesprompt=imagesprompt,
        history=history,
        modelinfo=modelinfo,
        temperature=temperature,
        max_tokens=max_tokens,
        prompt_name=prompt_name),
        media_type="text/event-stream")
