58 changes: 58 additions & 0 deletions llm/inference/tinyllama/app.py
@@ -0,0 +1,58 @@
import gradio as gr
import mindspore
from mindnlp.transformers import AutoModelForCausalLM, AutoTokenizer
from mindnlp.transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
from threading import Thread

# Loading the tokenizer and model from Hugging Face's model hub.
tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")


# Defining a custom stopping criteria class for the model's text generation.
class StopOnTokens(StoppingCriteria):
    def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor, **kwargs) -> bool:
        stop_ids = [2]  # IDs of tokens where the generation should stop.
        for stop_id in stop_ids:
            if input_ids[0][-1] == stop_id:  # Checking if the last generated token is a stop token.
                return mindspore.Tensor(True)
        return mindspore.Tensor(False)


# Function to generate model predictions.
def predict(message, history):
    history_transformer_format = history + [[message, ""]]
    stop = StopOnTokens()

    # Formatting the input for the model.
    messages = "</s>".join(["</s>".join(["\n<|user|>:" + item[0], "\n<|assistant|>:" + item[1]])
                            for item in history_transformer_format])
    model_inputs = tokenizer([messages], return_tensors="ms")
    streamer = TextIteratorStreamer(tokenizer, timeout=120, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=True,
        top_p=0.95,
        top_k=10,
        temperature=0.7,
        num_beams=1,
        stopping_criteria=StoppingCriteriaList([stop])
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()  # Starting the generation in a separate thread.
    partial_message = ""
    for new_token in streamer:
        partial_message += new_token
        if '</s>' in partial_message:  # Breaking the loop if the stop token is generated.
            break
        yield partial_message


# Setting up the Gradio chat interface.
gr.ChatInterface(predict,
title="Tinyllama_chatBot",
description="Ask Tiny llama any questions",
examples=['How to cook a fish?', 'Who is the president of US now?']
).launch() # Launching the web interface.
26 changes: 26 additions & 0 deletions llm/inference/tinyllama/readme.md
@@ -0,0 +1,26 @@
## TinyLlama Chatbot Implementation with Gradio

We offer an easy way to interact with TinyLlama. This guide explains how to set up a local Gradio demo for a chatbot using TinyLlama.

### Requirements
* Python>=3.9
* MindSpore>=2.4
* MindNLP>=0.4
* Gradio>=4.13.0

### Installation
`pip install -r requirements.txt`

### Usage

`python app.py`

* After running the script, open the local URL printed in your terminal in a web browser. (If the demo runs on a remote server, use SSH local port forwarding: `ssh -L [local port]:localhost:[remote port] [username]@[server address]`; see the example below.)
* Interact with the chatbot by typing questions or commands.
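
For example, a minimal sketch assuming the Gradio app serves on its default port 7860; the username and hostname below are placeholders for your own setup:

```bash
# Forward local port 7860 to port 7860 on the remote machine
# (adjust the ports, username, and host to match your server).
ssh -L 7860:localhost:7860 user@remote-server
# Then open http://localhost:7860 in your local browser.
```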


**Note:** If you are running TinyLlama on an OrangePi board, run the following command first to free cached memory:

```bash
sudo sync && echo 3 | sudo tee /proc/sys/vm/drop_caches
```
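
Here, `sync` flushes dirty filesystem buffers, and writing `3` to `/proc/sys/vm/drop_caches` asks the kernel to release the page cache and reclaimable dentries/inodes, leaving more RAM free for the model weights.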
3 changes: 3 additions & 0 deletions llm/inference/tinyllama/requirements.txt
@@ -0,0 +1,3 @@
mindspore>=2.4
mindnlp>=0.4
gradio>=4.13.0
6 changes: 3 additions & 3 deletions mindnlp/core/nn/functional.py
@@ -28,15 +28,15 @@ def tanh(input):
return ops.tanh(input)

def sigmoid(input):
if use_pyboost():
if use_pyboost() and not ON_ORANGE_PI:
return mindspore.mint.nn.functional.sigmoid(input)
return ops.sigmoid(input)

def silu(input):
if DEVICE_TARGET == 'CPU' or ON_ORANGE_PI:
return input * sigmoid(input)
if use_pyboost():
return mindspore.mint.nn.functional.silu(input)
if DEVICE_TARGET == 'CPU':
return input * sigmoid(input)
return ops.silu(input)

def mish(input):
7 changes: 6 additions & 1 deletion mindnlp/core/serialization.py
@@ -800,6 +800,10 @@ def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, bac
param = Tensor.from_numpy(array)
return param

def _rebuild_from_type_v2(func, new_type, args, state):
ret = func(*args)
return ret

@dataclass
class FakeParameter:

@@ -1213,7 +1217,8 @@ def find_class(self, mod_name, name):
return eval(name)
if mod_name == 'torch':
return str(name)

if mod_name == 'torch._tensor':
return eval(name)
mod_name = load_module_mapping.get(mod_name, mod_name)
return super().find_class(mod_name, name)

9 changes: 7 additions & 2 deletions mindnlp/transformers/cache_utils.py
@@ -24,6 +24,7 @@
import mindspore

from mindnlp.core import nn, ops
from mindnlp.configs import ON_ORANGE_PI
from .configuration_utils import PretrainedConfig
from ..utils import logging

@@ -446,8 +447,12 @@ def batch_repeat_interleave(self, repeats: int):
def batch_select_indices(self, indices: mindspore.Tensor):
"""Only keep the `indices` in the batch dimension of the cache. Used in contrastive search."""
for layer_idx in range(len(self)):
self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...]
self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...]
if ON_ORANGE_PI:
self.key_cache[layer_idx] = ops.getitem(self.key_cache[layer_idx], (indices, Ellipsis))
self.value_cache[layer_idx] = ops.getitem(self.value_cache[layer_idx], (indices, Ellipsis))
else:
self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...]
self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...]


class QuantizedCache(DynamicCache):
1 change: 1 addition & 0 deletions mindnlp/transformers/modeling_utils.py
@@ -2766,6 +2766,7 @@ def from_pretrained(
# "_commit_hash": commit_hash,
**has_file_kwargs,
}

if not has_file(pretrained_model_name_or_path, safe_weights_name, **has_file_kwargs):
Thread(
target=auto_conversion,
33 changes: 21 additions & 12 deletions mindnlp/transformers/models/llama/modeling_llama.py
@@ -25,7 +25,7 @@
from mindnlp.core import nn, ops, no_grad
import mindnlp.core.nn.functional as F
from mindnlp.core.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from mindnlp.configs import use_pyboost, SUPPORT_VIEW
from mindnlp.configs import use_pyboost, SUPPORT_VIEW, ON_ORANGE_PI
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...modeling_outputs import (
@@ -82,15 +82,17 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
if sequence_length != 1:
causal_mask = ops.triu(causal_mask, diagonal=1)
causal_mask *= ops.arange(target_length) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].broadcast_to((batch_size, 1, -1, -1))
# causal_mask = causal_mask[None, None, :, :].broadcast_to((batch_size, 1, -1, -1))
# speed up by unsqueeze
causal_mask = causal_mask.view(1, 1, *causal_mask.shape).broadcast_to((batch_size, 1, -1, -1))
if attention_mask is not None:
if SUPPORT_VIEW:
causal_mask = causal_mask.contiguous() # copy to contiguous memory for in-place edit
else:
causal_mask = causal_mask.copy()
mask_length = attention_mask.shape[-1]
# padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = ops.narrow(causal_mask, -1, 0, mask_length) + attention_mask[:, None, None, :]
padding_mask = ops.narrow(causal_mask, -1, 0, mask_length) + attention_mask.view(attention_mask.shape[0], 1, 1, attention_mask.shape[1])
padding_mask = padding_mask == 0
# causal_mask[:, :, :, :mask_length] = ops.narrow(causal_mask, -1, 0, mask_length).masked_fill(
# padding_mask, min_dtype
@@ -117,7 +119,7 @@ def __init__(self, hidden_size, eps=1e-6):
self.variance_epsilon = eps

def forward(self, hidden_states):
if not self.training and use_pyboost():
if not self.training and use_pyboost() and not ON_ORANGE_PI:
return F.rms_norm(hidden_states, self.weight, self.variance_epsilon)
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(mindspore.float32)
@@ -200,10 +202,10 @@ def forward(self, x, position_ids):
self._dynamic_frequency_update(position_ids)

# Core RoPE block
inv_freq_expanded = self.inv_freq[None, :, None].float().broadcast_to((position_ids.shape[0], -1, 1))
position_ids_expanded = position_ids[:, None, :].float()
inv_freq_expanded = self.inv_freq.view(1, -1, 1).float().broadcast_to((position_ids.shape[0], -1, 1))
position_ids_expanded = ops.unsqueeze(position_ids, 1).float()
# Force float32 (see https://github.com/huggingface/transformers/pull/29285)
freqs = ops.transpose((inv_freq_expanded.float() @ position_ids_expanded.float()), 1, 2)
freqs = ops.transpose(ops.matmul(inv_freq_expanded.float(), position_ids_expanded.float()), 1, 2)
emb = ops.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
@@ -423,7 +425,8 @@ def forward(
attn_weights = ops.matmul(query_states, ops.transpose(key_states, 2, 3)) / math.sqrt(self.head_dim)

if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
# causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
causal_mask = ops.narrow(attention_mask, 3, 0, key_states.shape[-2])
attn_weights = attn_weights + causal_mask

# upcast attention to fp32
@@ -895,7 +898,8 @@ def prepare_inputs_for_generation(
if past_key_values is not None:
if inputs_embeds is not None: # Exception 1
if 0 not in input_ids.shape:
input_ids = input_ids[:, -cache_position.shape[0] :]
# input_ids = input_ids[:, -cache_position.shape[0] :]
input_ids = ops.narrow(input_ids, 1, input_ids.shape[1] - cache_position.shape[0], cache_position.shape[0])
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
# input_ids = input_ids[:, cache_position]
input_ids = ops.index_select(input_ids, -1, cache_position)
@@ -905,8 +909,8 @@
position_ids = ops.cumsum(attention_mask.int(), -1) - 1
position_ids = ops.masked_fill(position_ids, attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]

# position_ids = position_ids[:, -input_ids.shape[1] :]
position_ids = ops.narrow(position_ids, 1, position_ids.shape[1] - input_ids.shape[1], input_ids.shape[1])
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and cache_position[0] == 0:
model_inputs = {"inputs_embeds": inputs_embeds}
@@ -1015,7 +1019,12 @@ def forward(
else:
sequence_lengths = -1

pooled_logits = logits[ops.arange(batch_size), sequence_lengths]
if ON_ORANGE_PI:
if isinstance(sequence_lengths, mindspore.Tensor):
sequence_lengths = sequence_lengths.to(mindspore.int32)
pooled_logits = ops.getitem(logits, (ops.arange(batch_size), sequence_lengths))
else:
pooled_logits = logits[ops.arange(batch_size), sequence_lengths]

loss = None
if labels is not None:
8 changes: 3 additions & 5 deletions mindnlp/utils/download.py
@@ -198,7 +198,8 @@ def http_get(url, path=None, md5sum=None, download_file_name=None, proxies=None,
req = requests.get(url, stream=True, timeout=10, proxies=proxies, headers=headers)

status = req.status_code
if status == 404:

if status in (404, 500):
raise EntryNotFoundError(f"Can not found url: {url}")
if status == 401:
raise GatedRepoError('You should have authorization to access the model.')
@@ -623,10 +624,7 @@ def download(
headers = {}
try:
pointer_path = http_get(url, storage_folder, download_file_name=relative_filename, proxies=proxies, headers=headers)
except (requests.exceptions.SSLError,
requests.exceptions.ProxyError,
requests.exceptions.ConnectionError,
requests.exceptions.Timeout):
except Exception:
# Otherwise, our Internet connection is down.
# etag is None
raise
42 changes: 34 additions & 8 deletions tests/ut/transformers/models/llama/test_modeling_llama.py
@@ -37,7 +37,7 @@

if is_mindspore_available():
import mindspore
from mindspore import ops
from mindnlp.core import ops

from mindnlp.transformers import (
CodeLlamaTokenizer,
@@ -47,7 +47,6 @@
LlamaTokenizer,
)


class LlamaModelTester:
def __init__(
self,
@@ -230,8 +229,8 @@ def create_and_check_decoder_model_past_large_inputs(
next_mask = ids_tensor((self.batch_size, 3), vocab_size=2)

# append to next input_ids and
next_input_ids = ops.cat([input_ids, next_tokens], axis=-1)
next_attention_mask = ops.cat([input_mask, next_mask], axis=-1)
next_input_ids = ops.cat([input_ids, next_tokens], dim=-1)
next_attention_mask = ops.cat([input_mask, next_mask], dim=-1)

output_from_no_past = model(
next_input_ids,
@@ -414,18 +413,18 @@ def test_generate_padding_right(self):

@require_mindspore
class LlamaIntegrationTest(unittest.TestCase):
@unittest.skip("Logits are not exactly the same, once we fix the instabalities somehow, will update!")
@slow
def test_model_7b_logits(self):
input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338]
model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", device_map="auto")
model = LlamaForCausalLM.from_pretrained("shakechen/Llama-2-7b-hf", mirror='modelscope', ms_dtype=mindspore.float16)

out = model(mindspore.tensor([input_ids]))
# Expected mean on dim = -1
EXPECTED_MEAN = mindspore.tensor([[-6.6550, -4.1227, -4.9859, -3.2406, 0.8262, -3.0033, 1.2964, -3.3699]])
# torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, atol=1e-2, rtol=1e-2)
assert ops.allclose(out.logits.mean(-1), EXPECTED_MEAN, atol=1e-2, rtol=1e-2)
# slicing logits[0, 0, 0:30]
EXPECTED_SLICE = mindspore.tensor([-12.8281, -7.4453, -0.4639, -8.0625, -7.2500, -8.0000, -6.4883, -7.7695, -7.8438, -7.0312, -6.2188, -7.1328, -1.8496, 1.9961, -8.6250, -6.7227, -12.8281, -6.9492, -7.0742, -7.7852, -7.5820, -7.9062, -6.9375, -7.9805, -8.3438, -8.1562, -8.0469, -7.6250, -7.7422, -7.3398,]) # fmt: skip
# torch.testing.assert_close(out[0, 0, :30], EXPECTED_SLICE, atol=1e-5, rtol=1e-5)
assert ops.allclose(out.logits[0, 0, :30], EXPECTED_SLICE, atol=1e-5, rtol=1e-5)

@unittest.skip("Logits are not exactly the same, once we fix the instabalities somehow, will update!")
@slow
@@ -567,3 +566,30 @@ def test_model_7b_logits(self):
]
infilling = tokenizer.batch_decode(generated_ids)
self.assertEqual(infilling, EXPECTED_INFILLING)


@require_mindspore
class TinyLlamaIntegrationTest(unittest.TestCase):
@slow
def test_model_1_1b_logits(self):
from mindnlp.transformers import AutoTokenizer, pipeline
model = "TinyLlama/TinyLlama_v1.1"
tokenizer = AutoTokenizer.from_pretrained(model)
pipeline = pipeline(
"text-generation",
model=model,
ms_dtype=mindspore.float16,
# device_map="auto",
)

sequences = pipeline(
'The TinyLlama project aims to pretrain a 1.1B Llama model on 3 trillion tokens. With some proper optimization, we can achieve this within a span of "just" 90 days using 16 A100-40G GPUs 🚀🚀. The training has started on 2023-09-01.',
do_sample=True,
top_k=10,
num_return_sequences=1,
repetition_penalty=1.5,
eos_token_id=tokenizer.eos_token_id,
max_length=500,
)
for seq in sequences:
print(f"Result: {seq['generated_text']}")