diff --git a/llm/inference/tinyllama/app.py b/llm/inference/tinyllama/app.py
new file mode 100644
index 000000000..e258329f6
--- /dev/null
+++ b/llm/inference/tinyllama/app.py
@@ -0,0 +1,58 @@
+import gradio as gr
+import mindspore
+from mindnlp.transformers import AutoModelForCausalLM, AutoTokenizer
+from mindnlp.transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
+from threading import Thread
+
+# Loading the tokenizer and model from Hugging Face's model hub.
+tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
+model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
+
+
+# Defining a custom stopping criteria class for the model's text generation.
+class StopOnTokens(StoppingCriteria):
+    def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor, **kwargs) -> bool:
+        stop_ids = [2]  # IDs of tokens where the generation should stop.
+        for stop_id in stop_ids:
+            if input_ids[0][-1] == stop_id:  # Checking if the last generated token is a stop token.
+                return mindspore.Tensor(True)
+        return mindspore.Tensor(False)
+
+
+# Function to generate model predictions.
+def predict(message, history):
+    history_transformer_format = history + [[message, ""]]
+    stop = StopOnTokens()
+
+    # Formatting the input for the model.
+    messages = "".join(["".join(["\n<|user|>:" + item[0], "\n<|assistant|>:" + item[1]])
+                        for item in history_transformer_format])
+    model_inputs = tokenizer([messages], return_tensors="ms")
+    streamer = TextIteratorStreamer(tokenizer, timeout=120, skip_prompt=True, skip_special_tokens=True)
+    generate_kwargs = dict(
+        model_inputs,
+        streamer=streamer,
+        max_new_tokens=1024,
+        do_sample=True,
+        top_p=0.95,
+        top_k=10,
+        temperature=0.7,
+        num_beams=1,
+        stopping_criteria=StoppingCriteriaList([stop])
+    )
+    t = Thread(target=model.generate, kwargs=generate_kwargs)
+    t.start()  # Starting the generation in a separate thread.
+    partial_message = ""
+    for new_token in streamer:
+        partial_message += new_token
+        if '</s>' in partial_message:  # Breaking the loop if the stop token is generated.
+            break
+        yield partial_message
+
+
+# Setting up the Gradio chat interface.
+gr.ChatInterface(predict,
+                 title="TinyLlama ChatBot",
+                 description="Ask TinyLlama any questions",
+                 examples=['How to cook a fish?', 'Who is the president of US now?']
+                 ).launch()  # Launching the web interface.
diff --git a/llm/inference/tinyllama/readme.md b/llm/inference/tinyllama/readme.md
new file mode 100644
index 000000000..93f60c58e
--- /dev/null
+++ b/llm/inference/tinyllama/readme.md
@@ -0,0 +1,26 @@
+## TinyLlama Chatbot Implementation with Gradio
+
+We offer an easy way to interact with TinyLlama. This guide explains how to set up a local Gradio demo for a chatbot using TinyLlama.
+
+### Requirements
+* Python>=3.9
+* MindSpore>=2.4
+* MindNLP>=0.4
+* Gradio>=4.13.0
+
+### Installation
+`pip install -r requirements.txt`
+
+### Usage
+
+`python app.py`
+
+* After running it, open the local URL displayed in your terminal in your web browser. (For a remote server, use SSH local port forwarding: `ssh -L [local port]:localhost:[remote port] [username]@[server address]`.)
+* Interact with the chatbot by typing questions or commands.
+
+
+**Note:** If you are running TinyLlama on an OrangePi board, run the following command first to free memory:
+
+```bash
+sudo sync && echo 3 | sudo tee /proc/sys/vm/drop_caches
+```
\ No newline at end of file
diff --git a/llm/inference/tinyllama/requirements.txt b/llm/inference/tinyllama/requirements.txt
new file mode 100644
index 000000000..6d068c1a5
--- /dev/null
+++ b/llm/inference/tinyllama/requirements.txt
@@ -0,0 +1,3 @@
+mindspore>=2.4
+mindnlp>=0.4
+gradio>=4.13.0
diff --git a/mindnlp/core/nn/functional.py b/mindnlp/core/nn/functional.py
index 2524a0daf..7b3499f13 100644
--- a/mindnlp/core/nn/functional.py
+++ b/mindnlp/core/nn/functional.py
@@ -28,15 +28,15 @@ def tanh(input):
     return ops.tanh(input)
 
 def sigmoid(input):
-    if use_pyboost():
+    if use_pyboost() and not ON_ORANGE_PI:
         return mindspore.mint.nn.functional.sigmoid(input)
     return ops.sigmoid(input)
 
 def silu(input):
+    if DEVICE_TARGET == 'CPU' or ON_ORANGE_PI:
+        return input * sigmoid(input)
     if use_pyboost():
         return mindspore.mint.nn.functional.silu(input)
-    if DEVICE_TARGET == 'CPU':
-        return input * sigmoid(input)
     return ops.silu(input)
 
 def mish(input):
diff --git a/mindnlp/core/serialization.py b/mindnlp/core/serialization.py
index 11b3eda85..85c44a633 100644
--- a/mindnlp/core/serialization.py
+++ b/mindnlp/core/serialization.py
@@ -800,6 +800,10 @@ def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, bac
     param = Tensor.from_numpy(array)
     return param
 
+def _rebuild_from_type_v2(func, new_type, args, state):
+    ret = func(*args)
+    return ret
+
 
 @dataclass
 class FakeParameter:
@@ -1213,7 +1217,8 @@ def find_class(self, mod_name, name):
             return eval(name)
         if mod_name == 'torch':
             return str(name)
-
+        if mod_name == 'torch._tensor':
+            return eval(name)
         mod_name = load_module_mapping.get(mod_name, mod_name)
         return super().find_class(mod_name, name)
 
diff --git a/mindnlp/transformers/cache_utils.py b/mindnlp/transformers/cache_utils.py
index 310df2eac..009af54e3 100644
--- a/mindnlp/transformers/cache_utils.py
+++ b/mindnlp/transformers/cache_utils.py
@@ -24,6 +24,7 @@
 import mindspore
 
 from mindnlp.core import nn, ops
+from mindnlp.configs import ON_ORANGE_PI
 from .configuration_utils import PretrainedConfig
 from ..utils import logging
 
@@ -446,8 +447,12 @@ def batch_repeat_interleave(self, repeats: int):
     def batch_select_indices(self, indices: mindspore.Tensor):
         """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search."""
         for layer_idx in range(len(self)):
-            self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...]
-            self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...]
+            if ON_ORANGE_PI:
+                self.key_cache[layer_idx] = ops.getitem(self.key_cache[layer_idx], (indices, Ellipsis))
+                self.value_cache[layer_idx] = ops.getitem(self.value_cache[layer_idx], (indices, Ellipsis))
+            else:
+                self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...]
+                self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...]
 
 
 class QuantizedCache(DynamicCache):
diff --git a/mindnlp/transformers/modeling_utils.py b/mindnlp/transformers/modeling_utils.py
index b8242ffcd..cef8b7b30 100644
--- a/mindnlp/transformers/modeling_utils.py
+++ b/mindnlp/transformers/modeling_utils.py
@@ -2766,6 +2766,7 @@ def from_pretrained(
                         # "_commit_hash": commit_hash,
                         **has_file_kwargs,
                     }
+
                     if not has_file(pretrained_model_name_or_path, safe_weights_name, **has_file_kwargs):
                         Thread(
                             target=auto_conversion,
diff --git a/mindnlp/transformers/models/llama/modeling_llama.py b/mindnlp/transformers/models/llama/modeling_llama.py
index b6062640a..5e30dff2e 100644
--- a/mindnlp/transformers/models/llama/modeling_llama.py
+++ b/mindnlp/transformers/models/llama/modeling_llama.py
@@ -25,7 +25,7 @@
 from mindnlp.core import nn, ops, no_grad
 import mindnlp.core.nn.functional as F
 from mindnlp.core.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
-from mindnlp.configs import use_pyboost, SUPPORT_VIEW
+from mindnlp.configs import use_pyboost, SUPPORT_VIEW, ON_ORANGE_PI
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache, StaticCache
 from ...modeling_outputs import (
@@ -82,7 +82,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
         if sequence_length != 1:
             causal_mask = ops.triu(causal_mask, diagonal=1)
         causal_mask *= ops.arange(target_length) > cache_position.reshape(-1, 1)
-        causal_mask = causal_mask[None, None, :, :].broadcast_to((batch_size, 1, -1, -1))
+        # causal_mask = causal_mask[None, None, :, :].broadcast_to((batch_size, 1, -1, -1))
+        # speed up by unsqueeze
+        causal_mask = causal_mask.view(1, 1, *causal_mask.shape).broadcast_to((batch_size, 1, -1, -1))
         if attention_mask is not None:
             if SUPPORT_VIEW:
                 causal_mask = causal_mask.contiguous()  # copy to contiguous memory for in-place edit
@@ -90,7 +92,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
                 causal_mask = causal_mask.copy()
             mask_length = attention_mask.shape[-1]
             # padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
-            padding_mask = ops.narrow(causal_mask, -1, 0, mask_length) + attention_mask[:, None, None, :]
+            padding_mask = ops.narrow(causal_mask, -1, 0, mask_length) + attention_mask.view(attention_mask.shape[0], 1, 1, attention_mask.shape[1])
             padding_mask = padding_mask == 0
             # causal_mask[:, :, :, :mask_length] = ops.narrow(causal_mask, -1, 0, mask_length).masked_fill(
             #     padding_mask, min_dtype
@@ -117,7 +119,7 @@ def __init__(self, hidden_size, eps=1e-6):
         self.variance_epsilon = eps
 
     def forward(self, hidden_states):
-        if not self.training and use_pyboost():
+        if not self.training and use_pyboost() and not ON_ORANGE_PI:
             return F.rms_norm(hidden_states, self.weight, self.variance_epsilon)
         input_dtype = hidden_states.dtype
         hidden_states = hidden_states.to(mindspore.float32)
@@ -200,10 +202,10 @@ def forward(self, x, position_ids):
             self._dynamic_frequency_update(position_ids)
 
         # Core RoPE block
-        inv_freq_expanded = self.inv_freq[None, :, None].float().broadcast_to((position_ids.shape[0], -1, 1))
-        position_ids_expanded = position_ids[:, None, :].float()
+        inv_freq_expanded = self.inv_freq.view(1, -1, 1).float().broadcast_to((position_ids.shape[0], -1, 1))
+        position_ids_expanded = ops.unsqueeze(position_ids, 1).float()
         # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
-        freqs = ops.transpose((inv_freq_expanded.float() @ position_ids_expanded.float()), 1, 2)
+        freqs = ops.transpose(ops.matmul(inv_freq_expanded.float(), position_ids_expanded.float()), 1, 2)
         emb = ops.cat((freqs, freqs), dim=-1)
         cos = emb.cos()
         sin = emb.sin()
@@ -423,7 +425,8 @@ def forward(
         attn_weights = ops.matmul(query_states, ops.transpose(key_states, 2, 3)) / math.sqrt(self.head_dim)
 
         if attention_mask is not None:  # no matter the length, we just slice it
-            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+            # causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+            causal_mask = ops.narrow(attention_mask, 3, 0, key_states.shape[-2])
             attn_weights = attn_weights + causal_mask
 
         # upcast attention to fp32
@@ -895,7 +898,8 @@ def prepare_inputs_for_generation(
         if past_key_values is not None:
             if inputs_embeds is not None:  # Exception 1
                 if 0 not in input_ids.shape:
-                    input_ids = input_ids[:, -cache_position.shape[0] :]
+                    # input_ids = input_ids[:, -cache_position.shape[0] :]
+                    input_ids = ops.narrow(input_ids, 1, input_ids.shape[1] - cache_position.shape[0], cache_position.shape[0])
             elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
                 # input_ids = input_ids[:, cache_position]
                 input_ids = ops.index_select(input_ids, -1, cache_position)
@@ -905,8 +909,8 @@ def prepare_inputs_for_generation(
             position_ids = ops.cumsum(attention_mask.int(), -1) - 1
             position_ids = ops.masked_fill(position_ids, attention_mask == 0, 1)
             if past_key_values:
-                position_ids = position_ids[:, -input_ids.shape[1] :]
-
+                # position_ids = position_ids[:, -input_ids.shape[1] :]
+                position_ids = ops.narrow(position_ids, 1, position_ids.shape[1] - input_ids.shape[1], input_ids.shape[1])
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and cache_position[0] == 0:
             model_inputs = {"inputs_embeds": inputs_embeds}
@@ -1015,7 +1019,12 @@ def forward(
         else:
             sequence_lengths = -1
 
-        pooled_logits = logits[ops.arange(batch_size), sequence_lengths]
+        if ON_ORANGE_PI:
+            if isinstance(sequence_lengths, mindspore.Tensor):
+                sequence_lengths = sequence_lengths.to(mindspore.int32)
+            pooled_logits = ops.getitem(logits, (ops.arange(batch_size), sequence_lengths))
+        else:
+            pooled_logits = logits[ops.arange(batch_size), sequence_lengths]
 
         loss = None
         if labels is not None:
diff --git a/mindnlp/utils/download.py b/mindnlp/utils/download.py
index fd1b11b5a..d2ca302d6 100644
--- a/mindnlp/utils/download.py
+++ b/mindnlp/utils/download.py
@@ -198,7 +198,8 @@ def http_get(url, path=None, md5sum=None, download_file_name=None, proxies=None,
 
     req = requests.get(url, stream=True, timeout=10, proxies=proxies, headers=headers)
     status = req.status_code
-    if status == 404:
+
+    if status in (404, 500):
         raise EntryNotFoundError(f"Can not found url: {url}")
     if status == 401:
         raise GatedRepoError('You should have authorization to access the model.')
@@ -623,10 +624,7 @@ def download(
         headers = {}
         try:
             pointer_path = http_get(url, storage_folder, download_file_name=relative_filename, proxies=proxies, headers=headers)
-        except (requests.exceptions.SSLError,
-                requests.exceptions.ProxyError,
-                requests.exceptions.ConnectionError,
-                requests.exceptions.Timeout):
+        except Exception:
             # Otherwise, our Internet connection is down.
             # etag is None
             raise
diff --git a/tests/ut/transformers/models/llama/test_modeling_llama.py b/tests/ut/transformers/models/llama/test_modeling_llama.py
index c607f4c89..2f32f40eb 100644
--- a/tests/ut/transformers/models/llama/test_modeling_llama.py
+++ b/tests/ut/transformers/models/llama/test_modeling_llama.py
@@ -37,7 +37,7 @@
 
 if is_mindspore_available():
     import mindspore
-    from mindspore import ops
+    from mindnlp.core import ops
 
     from mindnlp.transformers import (
         CodeLlamaTokenizer,
@@ -47,7 +47,6 @@
         LlamaTokenizer,
     )
 
-
 class LlamaModelTester:
     def __init__(
         self,
@@ -230,8 +229,8 @@ def create_and_check_decoder_model_past_large_inputs(
         next_mask = ids_tensor((self.batch_size, 3), vocab_size=2)
 
         # append to next input_ids and
-        next_input_ids = ops.cat([input_ids, next_tokens], axis=-1)
-        next_attention_mask = ops.cat([input_mask, next_mask], axis=-1)
+        next_input_ids = ops.cat([input_ids, next_tokens], dim=-1)
+        next_attention_mask = ops.cat([input_mask, next_mask], dim=-1)
 
         output_from_no_past = model(
             next_input_ids,
@@ -414,18 +413,18 @@ def test_generate_padding_right(self):
 
 @require_mindspore
 class LlamaIntegrationTest(unittest.TestCase):
-    @unittest.skip("Logits are not exactly the same, once we fix the instabalities somehow, will update!")
     @slow
     def test_model_7b_logits(self):
         input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338]
-        model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", device_map="auto")
+        model = LlamaForCausalLM.from_pretrained("shakechen/Llama-2-7b-hf", mirror='modelscope', ms_dtype=mindspore.float16)
+
         out = model(mindspore.tensor([input_ids]))
         # Expected mean on dim = -1
         EXPECTED_MEAN = mindspore.tensor([[-6.6550, -4.1227, -4.9859, -3.2406, 0.8262, -3.0033, 1.2964, -3.3699]])
-        # torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, atol=1e-2, rtol=1e-2)
+        assert ops.allclose(out.logits.mean(-1), EXPECTED_MEAN, atol=1e-2, rtol=1e-2)
         # slicing logits[0, 0, 0:30]
         EXPECTED_SLICE = mindspore.tensor([-12.8281, -7.4453, -0.4639, -8.0625, -7.2500, -8.0000, -6.4883, -7.7695, -7.8438, -7.0312, -6.2188, -7.1328, -1.8496, 1.9961, -8.6250, -6.7227, -12.8281, -6.9492, -7.0742, -7.7852, -7.5820, -7.9062, -6.9375, -7.9805, -8.3438, -8.1562, -8.0469, -7.6250, -7.7422, -7.3398,])  # fmt: skip
-        # torch.testing.assert_close(out[0, 0, :30], EXPECTED_SLICE, atol=1e-5, rtol=1e-5)
+        assert ops.allclose(out.logits[0, 0, :30], EXPECTED_SLICE, atol=1e-5, rtol=1e-5)
 
     @unittest.skip("Logits are not exactly the same, once we fix the instabalities somehow, will update!")
     @slow
@@ -567,3 +566,30 @@ def test_model_7b_logits(self):
         ]
         infilling = tokenizer.batch_decode(generated_ids)
         self.assertEqual(infilling, EXPECTED_INFILLING)
+
+
+@require_mindspore
+class TinyLlamaIntegrationTest(unittest.TestCase):
+    @slow
+    def test_model_1_1b_logits(self):
+        from mindnlp.transformers import AutoTokenizer, pipeline
+        model = "TinyLlama/TinyLlama_v1.1"
+        tokenizer = AutoTokenizer.from_pretrained(model)
+        pipeline = pipeline(
+            "text-generation",
+            model=model,
+            ms_dtype=mindspore.float16,
+            # device_map="auto",
+        )
+
+        sequences = pipeline(
+            'The TinyLlama project aims to pretrain a 1.1B Llama model on 3 trillion tokens. With some proper optimization, we can achieve this within a span of "just" 90 days using 16 A100-40G GPUs 🚀🚀. The training has started on 2023-09-01.',
+            do_sample=True,
+            top_k=10,
+            num_return_sequences=1,
+            repetition_penalty=1.5,
+            eos_token_id=tokenizer.eos_token_id,
+            max_length=500,
+        )
+        for seq in sequences:
+            print(f"Result: {seq['generated_text']}")