import json
import asyncio
import base64
import threading
from WebUI.configs.basicconfig import *
from fastapi.responses import StreamingResponse
from WebUI.configs.webuiconfig import InnerJsonConfigWebUIParse
from WebUI.Server.db.repository import add_chat_history_to_db, update_chat_history
from WebUI.Server.chat.StreamHandler import StreamSpeakHandler
from WebUI.Server.utils import FastAPI

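# Load a plain causal-LM checkpoint with transformers and stash the model,
# tokenizer and model name on the FastAPI app for later requests.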
def load_causallm_model(app: FastAPI, model_name, model_path, device):
    from transformers import AutoModelForCausalLM, AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, torch_dtype="auto", device_map=device)
    app._model = model
    app._tokenizer = tokenizer
    app._model_name = model_name

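# Load a Llama-family checkpoint behind a transformers text-generation pipeline,
# wiring in a TextIteratorStreamer so tokens can be consumed as they are produced.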
def load_llama_model(app: FastAPI, model_name, model_path, device):
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, pipeline
    from langchain.llms.huggingface_pipeline import HuggingFacePipeline
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.float16, device_map=device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        temperature=0.2,
        torch_dtype=torch.float16,
        streamer=streamer,
    )
    # pipe.model.config.pad_token_id = pipe.model.config.eos_token_id
    # llm_model = HuggingFacePipeline(pipeline=pipe)
    app._model = pipe
    app._streamer = streamer
    app._tokenizer = tokenizer
    app._model_name = model_name

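# Resolve the requested code model from the WebUI configuration and dispatch to
# the loader that matches its configured "load_type".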
def init_code_models(app: FastAPI, args):
    model_name = args.model_names[0]
    model_path = args.model_path
    if len(model_name) == 0 or len(model_path) == 0:
        return
    configinst = InnerJsonConfigWebUIParse()
    webui_config = configinst.dump()
    model_info: Dict[str, Any] = {"mtype": ModelType.Unknown, "msize": ModelSize.Unknown, "msubtype": ModelSubType.Unknown, "mname": str, "config": dict}
    model_info["mtype"], model_info["msize"], model_info["msubtype"] = GetModelInfoByName(webui_config, model_name)
    model_info["mname"] = model_name
    model_config = GetModelConfig(webui_config, model_info)
    load_type = model_config.get("load_type", "")
    if load_type == "causallm":
        load_causallm_model(app=app, model_name=model_name, model_path=model_path, device=args.device)
    elif load_type == "llama":
        load_llama_model(app=app, model_name=model_name, model_path=model_path, device=args.device)

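# Generate a code-chat reply for `query`. The response is produced by the nested
# async generator below and wrapped in a server-sent-events StreamingResponse.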
def code_model_chat(
    model: Any,
    tokenizer: Any,
    async_callback: Any,
    modelinfo: Any,
    query: str,
    imagesdata: List[str],
    audiosdata: List[str],
    videosdata: List[str],
    imagesprompt: List[str],
    history: List[dict],
    stream: bool,
    speechmodel: dict,
    temperature: float,
    max_tokens: Optional[int],
    prompt_name: str,
):
    if modelinfo is None:
        return json.dumps(
            {"text": "Unusual error!", "chat_history_id": 123},
            ensure_ascii=False)

    async def code_chat_iterator(model: Any,
                                 tokenizer: Any,
                                 async_callback: Any,
                                 query: str,
                                 imagesdata: List[str],
                                 audiosdata: List[str],
                                 videosdata: List[str],
                                 imagesprompt: List[str],
                                 history: List[dict] = [],
                                 modelinfo: Any = None,
                                 temperature: float = 0.7,
                                 max_tokens: Optional[int] = 2048,
                                 prompt_name: str = prompt_name,
                                 ) -> AsyncIterable[str]:

        from WebUI.Server.utils import detect_device
        configinst = InnerJsonConfigWebUIParse()
        webui_config = configinst.dump()
        model_name = modelinfo["mname"]
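        # Optionally set up a text-to-speech handler so that generated text can
        # be spoken through the configured local or cloud speech backend.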
        speak_handler = None
        if len(speechmodel):
            modeltype = speechmodel.get("type", "")
            provider = speechmodel.get("provider", "")
            #spmodel = speechmodel.get("model", "")
            spspeaker = speechmodel.get("speaker", "")
            speechkey = speechmodel.get("speech_key", "")
            speechregion = speechmodel.get("speech_region", "")
            if modeltype == "local" or modeltype == "cloud":
                speak_handler = StreamSpeakHandler(run_place=modeltype, provider=provider, synthesis=spspeaker, subscription=speechkey, region=speechregion)

        answer = ""
        chat_history_id = add_chat_history_to_db(chat_type="llm_chat", query=query)
        modelconfig = GetModelConfig(webui_config, modelinfo)
        device = modelconfig.get("device", "auto")
        device = "cuda" if device == "gpu" else detect_device() if device == "auto" else device
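        # stable-code-3b path: run a single blocking generate() call and emit
        # the whole completion as one fenced code block.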
        if model_name == "stable-code-3b":
            if max_tokens is None:
                max_tokens = 512
            inputs = tokenizer(query, return_tensors="pt").to(model.device)
            tokens = model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                temperature=0.2,
                do_sample=True,
            )
            answer = tokenizer.decode(tokens[0], skip_special_tokens=True)
            #answer = answer.replace(query, '').strip()
            #while(answer.startswith(("'", '"', ' ', ',', '.', '!', '?'))):
            #    answer = answer[1:].strip()
            answer = "```python\n" + answer + "\n```"
            if speak_handler: speak_handler.on_llm_new_token(answer)
            yield json.dumps(
                {"text": answer, "chat_history_id": chat_history_id},
                ensure_ascii=False)
            await asyncio.sleep(0.1)
            if speak_handler: speak_handler.on_llm_end(None)

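        # CodeLlama path: the loaded object is a transformers text-generation
        # pipeline, so iterate over its returned sequences and stream each chunk.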
        elif model_name == "CodeLlama-7b-Python-hf" or \
            model_name == "CodeLlama-13b-Python-hf" or \
            model_name == "CodeLlama-7b-Instruct-hf" or \
            model_name == "CodeLlama-13b-Instruct-hf":
            sequences = model(
                query,
                #do_sample=True,
                temperature=0.2,
                top_p=0.95,
                num_return_sequences=1,
                eos_token_id=tokenizer.eos_token_id,
                max_length=512,
            )
            answer = ""
            yield json.dumps(
                {"text": "```python\n", "chat_history_id": chat_history_id},
                ensure_ascii=False)
            for seq in sequences:
                answer += seq['generated_text']  # accumulate so the full reply can be spoken and saved below
                yield json.dumps(
                    {"text": seq['generated_text'], "chat_history_id": chat_history_id},
                    ensure_ascii=False)
                await asyncio.sleep(0.1)
            yield json.dumps(
                {"text": "\n```", "chat_history_id": chat_history_id},
                ensure_ascii=False)
            if speak_handler:
                speak_handler.on_llm_new_token(answer)
                speak_handler.on_llm_end(None)

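        # Persist the completed answer to the chat-history record created above.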
        update_chat_history(chat_history_id, response=answer)

    return StreamingResponse(code_chat_iterator(
        model=model,
        tokenizer=tokenizer,
        async_callback=async_callback,
        query=query,
        imagesdata=imagesdata,
        audiosdata=audiosdata,
        videosdata=videosdata,
        imagesprompt=imagesprompt,
        history=history,
        modelinfo=modelinfo,
        temperature=temperature,
        max_tokens=max_tokens,
        prompt_name=prompt_name),
        media_type="text/event-stream")
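
# A minimal usage sketch (an assumption, not part of the original module): a
# FastAPI route could forward a request to code_model_chat roughly as below,
# passing a model_info dict built the same way as in init_code_models. The
# route path and the CodeChatRequest model are hypothetical placeholders.
#
#   @app.post("/code_model_chat")
#   def chat(request: CodeChatRequest):
#       return code_model_chat(
#           model=app._model, tokenizer=app._tokenizer, async_callback=None,
#           modelinfo=model_info, query=request.query, imagesdata=[],
#           audiosdata=[], videosdata=[], imagesprompt=[], history=[],
#           stream=True, speechmodel={}, temperature=0.7, max_tokens=512,
#           prompt_name="default")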