diff --git a/llm/inference/qwen1.5-0.5b/app.py b/llm/inference/qwen1.5-0.5b/app.py
new file mode 100644
index 000000000..24f7da026
--- /dev/null
+++ b/llm/inference/qwen1.5-0.5b/app.py
@@ -0,0 +1,58 @@
+import gradio as gr
+import mindspore
+from mindnlp.transformers import AutoModelForCausalLM, AutoTokenizer
+from mindnlp.transformers import TextIteratorStreamer
+from threading import Thread
+
+# Loading the tokenizer and model from Hugging Face's model hub.
+tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen1.5-0.5B-Chat", ms_dtype=mindspore.float16)
+model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen1.5-0.5B-Chat", ms_dtype=mindspore.float16)
+
+system_prompt = "You are a helpful and friendly chatbot"
+
+def build_input_from_chat_history(chat_history, msg: str):
+    messages = [{'role': 'system', 'content': system_prompt}]
+    for user_msg, ai_msg in chat_history:
+        messages.append({'role': 'user', 'content': user_msg})
+        messages.append({'role': 'assistant', 'content': ai_msg})
+    messages.append({'role': 'user', 'content': msg})
+    return messages
+
+# Function to generate model predictions.
+def predict(message, history):
+    history_transformer_format = history + [[message, ""]]
+
+    # Formatting the input for the model.
+    messages = build_input_from_chat_history(history, message)
+    input_ids = tokenizer.apply_chat_template(
+            messages,
+            add_generation_prompt=True,
+            return_tensors="ms",
+            tokenize=True
+        )
+    streamer = TextIteratorStreamer(tokenizer, timeout=120, skip_prompt=True, skip_special_tokens=True)
+    generate_kwargs = dict(
+        input_ids=input_ids,
+        streamer=streamer,
+        max_new_tokens=1024,
+        do_sample=True,
+        top_p=0.9,
+        temperature=0.1,
+        num_beams=1,
+    )
+    t = Thread(target=model.generate, kwargs=generate_kwargs)
+    t.start()  # Starting the generation in a separate thread.
+    partial_message = ""
+    for new_token in streamer:
+        partial_message += new_token
+        if '</s>' in partial_message:  # Breaking the loop if the stop token is generated.
+            break
+        yield partial_message
+
+
+# Setting up the Gradio chat interface.
+gr.ChatInterface(predict,
+                 title="Qwen1.5-0.5b-Chat",
+                 description="Ask a few questions",
+                 examples=['你是谁?', '介绍一下华为公司']
+                 ).launch()  # Launching the web interface.
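# --- Illustrative sketch (not part of the patch) ----------------------------
# What build_input_from_chat_history() above produces, and how the chat
# template turns it into the prompt the model is conditioned on. Assumes the
# tokenizer and helper defined in app.py above are in scope; the history
# contents here are made up for illustration.

history = [("你好", "你好！有什么可以帮你？")]
messages = build_input_from_chat_history(history, "介绍一下你自己")
# messages == [
#     {'role': 'system', 'content': 'You are a helpful and friendly chatbot'},
#     {'role': 'user', 'content': '你好'},
#     {'role': 'assistant', 'content': '你好！有什么可以帮你？'},
#     {'role': 'user', 'content': '介绍一下你自己'},
# ]
prompt_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(prompt_text)   # ChatML-style text ending with the assistant turn marker
# -----------------------------------------------------------------------------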
diff --git a/llm/inference/tinyllama/app_quant.py b/llm/inference/tinyllama/app_quant.py
new file mode 100644
index 000000000..048631ffc
--- /dev/null
+++ b/llm/inference/tinyllama/app_quant.py
@@ -0,0 +1,62 @@
+import gradio as gr
+import mindspore
+from mindnlp.transformers import AutoModelForCausalLM, AutoTokenizer
+from mindnlp.transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
+from mindnlp.quant.smooth_quant import quantize, w8x8
+from threading import Thread
+
+mindspore.set_context(pynative_synchronize=True)
+# Loading the tokenizer and model from Hugging Face's model hub.
+tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
+model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", ms_dtype=mindspore.float16)
+
+quantize_cfg = w8x8(model.model.config)
+quantize(model, cfg=quantize_cfg)
+
+# Defining a custom stopping criteria class for the model's text generation.
+class StopOnTokens(StoppingCriteria):
+    def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor, **kwargs) -> bool:
+        stop_ids = [2]  # IDs of tokens where the generation should stop.
+        for stop_id in stop_ids:
+            if input_ids[0][-1] == stop_id:  # Checking if the last generated token is a stop token.
+                return mindspore.Tensor(True)
+        return mindspore.Tensor(False)
+
+
+# Function to generate model predictions.
+def predict(message, history):
+    history_transformer_format = history + [[message, ""]]
+    stop = StopOnTokens()
+
+    # Formatting the input for the model.
+    messages = "".join(["".join(["\n<|user|>:" + item[0], "\n<|assistant|>:" + item[1]])
+                        for item in history_transformer_format])
+    model_inputs = tokenizer([messages], return_tensors="ms")
+    streamer = TextIteratorStreamer(tokenizer, timeout=3600, skip_prompt=True, skip_special_tokens=True)
+    generate_kwargs = dict(
+        model_inputs,
+        streamer=streamer,
+        max_new_tokens=1024,
+        do_sample=True,
+        top_p=0.95,
+        top_k=10,
+        temperature=0.7,
+        num_beams=1,
+        stopping_criteria=StoppingCriteriaList([stop])
+    )
+    t = Thread(target=model.generate, kwargs=generate_kwargs)
+    t.start()  # Starting the generation in a separate thread.
+    partial_message = ""
+    for new_token in streamer:
+        partial_message += new_token
+        if '</s>' in partial_message:  # Breaking the loop if the stop token is generated.
+            break
+        yield partial_message
+
+
+# Setting up the Gradio chat interface.
+gr.ChatInterface(predict,
+                 title="Tinyllama_chatBot",
+                 description="Ask Tiny llama any questions",
+                 examples=['How to cook a fish?', 'Who is the president of US now?']
+                 ).launch()  # Launching the web interface.
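# --- Illustrative sketch (not part of the patch) ----------------------------
# w8x8() above needs no calibration data: every projection gets per-token
# absmax W8X8 quantization. The other recipes in configs.py (shown later in
# this patch) can be swapped in the same way; smooth()/smsd() additionally
# expect a serialized dict of per-channel activation maxima. The path below is
# a placeholder, and `model` is the TinyLlama model loaded in app_quant.py.

from mindnlp.quant.smooth_quant import quantize, smooth

act_scales_path = "act_scales/tinyllama.ckpt"    # hypothetical calibration file
quantize_cfg = smooth(model.model.config, act_scales_path)
quantize(model, cfg=quantize_cfg)    # replaces matching nn.Linear modules in place
# -----------------------------------------------------------------------------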
diff --git a/mindnlp/core/ops/comparison.py b/mindnlp/core/ops/comparison.py
index e134a84d2..45a9975ae 100644
--- a/mindnlp/core/ops/comparison.py
+++ b/mindnlp/core/ops/comparison.py
@@ -2,7 +2,7 @@
 import numpy as np
 import mindspore
 from mindspore import ops
-from mindnlp.configs import use_pyboost, ON_ORANGE_PI
+from mindnlp.configs import use_pyboost
 
 # allclose
 def allclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False):
@@ -128,8 +128,6 @@ def not_equal(input, other):
 
 # sort
 def sort(input, *, dim=-1, descending=False, stable=False):
-    if ON_ORANGE_PI:
-        return topk(input, input.shape[dim], dim, descending)
     if use_pyboost():
         return mindspore.mint.sort(input, dim=dim, descending=descending, stable=stable)
     return ops.sort(input, dim, descending)
diff --git a/mindnlp/quant/__init__.py b/mindnlp/quant/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/mindnlp/quant/smooth_quant/__init__.py b/mindnlp/quant/smooth_quant/__init__.py
new file mode 100644
index 000000000..5d4e9bf8c
--- /dev/null
+++ b/mindnlp/quant/smooth_quant/__init__.py
@@ -0,0 +1,3 @@
+"""smooth quant"""
+from .quant import *
+from .configs import *
diff --git a/mindnlp/quant/smooth_quant/configs.py b/mindnlp/quant/smooth_quant/configs.py
new file mode 100644
index 000000000..9ae9bbd94
--- /dev/null
+++ b/mindnlp/quant/smooth_quant/configs.py
@@ -0,0 +1,135 @@
+"""quant configs"""
+def no(model_cfg, act_max):
+    return {}
+
+
+# Static mixed-precision decomposition
+def sd(model_cfg, act_max):
+    quant_cfg = {}
+    h_mx, d_mx = findN(0.04 * model_cfg.hidden_size), findN(
+        0.1 * model_cfg.intermediate_size
+    )
+    scale, step = 4, 4 / model_cfg.num_hidden_layers
+    for i in range(model_cfg.num_hidden_layers):
+        scale = max(0, scale - step)
+        h_cur, d_cur = max(16, h_mx >> int(scale)), max(32, d_mx >> int(scale))
+        for name in ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj"]:
+            quant_cfg[str(i) + "." + name] = {
+                "type": "W8SD",
+                "act_scale": True,
+                "alpha": h_cur,
+            }
+        quant_cfg[str(i) + ".down_proj"] = {
+            "type": "W8SD",
+            "act_scale": True,
+            "alpha": d_cur,
+        }
+    quant_cfg["lm_head"] = {"type": "W8SD"}
+    quant_cfg["act_scales_path"] = act_max
+    return quant_cfg
+
+
+def findN(N):
+    sum = 1
+    while True:
+        if sum * 2 > N:
+            return sum
+        sum = sum * 2
+
+
+# Smooth the activations (SmoothQuant)
+def smooth(model_cfg, act_max):
+    quant_cfg = {}
+    for i in range(model_cfg.num_hidden_layers):
+        for name in ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj"]:
+            quant_cfg[str(i) + "." + name] = {"type": "W8X8"}
+        # Effect of act_scale on a specific layer: for W8X8 the layer is smoothed; for W8SD the act_scale drives the mixed-precision decomposition.
+        quant_cfg[str(i) + ".down_proj"] = {
+            "type": "W8X8",
+            "act_scale": True,
+            "alpha": 0.85,
+        }
+    quant_cfg["lm_head"] = {"type": "W8X8", "act_scale": True, "alpha": 0.85}
+    quant_cfg["act_scales_path"] = act_max
+    quant_cfg["alpha"] = 0.85  # SmoothQuant migration strength
+    quant_cfg["smooth"] = (
+        True  # Global smoothing folds the activation scaling into RMSNorm, so it adds no extra cost, but it cannot be applied to down_proj.
+    )
+    return quant_cfg
+
+
+# Mixed-precision decomposition for down_proj, smoothed activations elsewhere
+def smsd(model_cfg, act_max):
+    quant_cfg = {}
+    d_mx = findN(0.1 * model_cfg.intermediate_size)
+    scale, step = 4, 4 / model_cfg.num_hidden_layers
+    for i in range(model_cfg.num_hidden_layers):
+        scale = max(0, scale - step)
+        d_cur = max(32, d_mx >> int(scale))
+        for name in ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj"]:
+            quant_cfg[str(i) + "." + name] = {"type": "W8X8"}
+        quant_cfg[str(i) + ".down_proj"] = {
+            "type": "W8SD",
+            "act_scale": True,
+            "alpha": d_cur,
+        }
+    quant_cfg["lm_head"] = {"type": "W8SD", "act_scale": True, "alpha": 64}
+    quant_cfg["act_scales_path"] = act_max
+    quant_cfg["smooth"] = True
+    return quant_cfg
+
+
+# Weight-only int8 quantization
+def w8(model_cfg, act_max):
+    quant_cfg = {}
+    for i in range(model_cfg.num_hidden_layers):
+        for name in [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+            "o_proj",
+            "gate_proj",
+            "up_proj",
+            "down_proj",
+        ]:
+            quant_cfg[str(i) + "." + name] = {"type": "W8"}
+    quant_cfg["lm_head"] = {"type": "W8"}
+    return quant_cfg
+
+
+# Dynamic mixed-precision decomposition
+def w8dx(model_cfg, act_max):
+    quant_cfg = {}
+    for i in range(model_cfg.num_hidden_layers):
+        for name in [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+            "o_proj",
+            "gate_proj",
+            "up_proj",
+            "down_proj",
+        ]:
+            quant_cfg[str(i) + "." + name] = {"type": "W8DX"}
+    # quant_cfg["lm_head"] = {"type": "W8DX"}  # uncomment if needed
+    # quant_cfg["act_scales_path"] = act_max  # uncomment if needed
+    # quant_cfg["smooth"] = True  # uncomment if needed
+    return quant_cfg
+
+
+# Per-token absmax quantization
+def w8x8(model_cfg):
+    quant_cfg = {}
+    for i in range(model_cfg.num_hidden_layers):
+        for name in [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+            "o_proj",
+            "gate_proj",
+            "up_proj",
+            "down_proj",
+        ]:
+            quant_cfg[str(i) + "." + name] = {"type": "W8X8"}
+    quant_cfg["lm_head"] = {"type": "W8X8"}
+    return quant_cfg
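# --- Illustrative sketch (not part of the patch) ----------------------------
# What the helpers in configs.py return: a dict keyed by
# "<layer index>.<projection name>", plus "lm_head" and a few global switches.
# The namespace below only mimics the three attributes the helpers read; the
# sizes are TinyLlama-like and chosen for illustration.

from types import SimpleNamespace
from mindnlp.quant.smooth_quant import w8x8, smsd

toy_cfg = SimpleNamespace(hidden_size=2048, intermediate_size=5632, num_hidden_layers=22)
print(w8x8(toy_cfg)["0.q_proj"])     # {'type': 'W8X8'}
print(w8x8(toy_cfg)["lm_head"])      # {'type': 'W8X8'}
cfg = smsd(toy_cfg, "act_scales.ckpt")   # second argument: path to calibrated activation maxima
print(cfg["0.down_proj"])            # {'type': 'W8SD', 'act_scale': True, 'alpha': 64}
# quantize() in quant.py below matches these keys against module names such as
# model.layers.0.self_attn.q_proj and model.layers.0.mlp.down_proj.
# -----------------------------------------------------------------------------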
diff --git a/mindnlp/quant/smooth_quant/quant.py b/mindnlp/quant/smooth_quant/quant.py
new file mode 100644
index 000000000..ca265a4cf
--- /dev/null
+++ b/mindnlp/quant/smooth_quant/quant.py
@@ -0,0 +1,240 @@
+"""quant implement"""
+from typing import Optional, Tuple
+import mindspore
+from mindspore import Tensor
+from mindspore.ops.primitive import PrimitiveWithInfer, prim_attr_register
+from mindnlp.core import nn, ops
+from mindnlp.core.serialization import load
+from mindnlp.configs import ON_ORANGE_PI
+
+from .smooth import smooth_lm
+
+
+class BatchMatMulV2(PrimitiveWithInfer):
+    @prim_attr_register
+    def __init__(self, transpose_a=False, transpose_b=False):  # pylint: disable=super-init-not-called
+        """Initialize BatchMatMul."""
+        self.init_prim_io_names(inputs=["x1", "x2", "bias", "offset_w"], outputs=["y"])
+        self.add_prim_attr("adj_x1", self.transpose_a)
+        self.add_prim_attr("adj_x2", self.transpose_b)
+
+    def infer_shape(self, x1_shape, x2_shape, bias_shape=None):
+        return x1_shape[:-1] + [x2_shape[-1]]
+
+    def infer_dtype(self, x_dtype, v_dtype, bias_dtype=None):
+        return x_dtype
+
+
+matmulInteger = BatchMatMulV2()
+
+
+def quantize_mat(mat: Tensor) -> Tuple[Tensor, Tensor]:
+    max_val = (ops.max(ops.abs(mat), dim=-1)[0] / 127.0).to(dtype=mat.dtype)
+    mat = (mat / max_val[..., None]).to(dtype=mindspore.int8)
+    return mat, max_val
+
+
+def dequantize_mat(mat: Tensor, max_val: Tensor):
+    return ops.mul(mat, max_val.unsqueeze(-1))
+
+
+def decomposition(mat: Tensor, unq_idx: Tensor, t: Tensor):
+    return mat.mul(t.to(dtype=mat.dtype)), mat[..., unq_idx]
+    # mat = mat.clone()
+    # mat_unq = mat[..., unq_idx]
+    # if mat.dim() == 3:
+    #     mat[:, :, unq_idx] = 0
+    # elif mat.dim() == 4:
+    #     mat[:, :, :, unq_idx] = 0
+    # elif mat.dim() == 2:
+    #     mat[:, unq_idx] = 0
+    # return mat, mat_unq
+
+
+def get_unq_idx_topk(mat: Tensor, k: int = 64):
+    idx = ops.topk(ops.max(mat.view(-1, mat.shape[-1]).abs(), dim=-2)[0], k, dim=-1)[1]
+    t = ops.ones((mat.shape[-1]), dtype=mat.dtype)
+    t = t.copy()
+    if ON_ORANGE_PI:
+        ops.setitem(t, idx, 0)
+    else:
+        t[idx] = 0
+    return idx, t
+
+
+def get_unq_idx_thres(mat: Tensor, threshold: float = 6.0):
+    k = ops.max(mat.view(-1, mat.shape[-1]).abs(), dim=-2)[0] >= threshold
+    return ops.nonzero(k).view(-1), k
+
+
+def qMatmul(x_q: Tensor, x_max: Tensor, weight_q: Tensor, w_max: Tensor, dtype):
+    res_q = matmulInteger(x_q, weight_q)
+    mx = nn.functional.linear(x_max.unsqueeze(-1), w_max.unsqueeze(-1))
+    res = ops.mul(res_q.to(dtype=mindspore.float32), mx.to(mindspore.float32)).to(dtype=dtype)
+    return res
+
+
+class W8Linear(nn.Module):
+    def __init__(
+        self,
+        origin_weight: Tensor,
+        bias: Optional[Tensor] = None,
+        act_max: Optional[Tensor] = None,
+        alpha=32,
+    ):
+        super().__init__()
+        self.bias = None if bias is None else nn.Parameter(bias, requires_grad=False)
+        self.dtype = origin_weight.dtype
+        self.alpha = alpha
+        self.weight_q, self.max_val = quantize_mat(origin_weight)
+        self.weight_q = nn.Parameter(self.weight_q, requires_grad=False)
+        self.max_val = nn.Parameter(self.max_val, requires_grad=False)
+
+    def forward(self, x: Tensor) -> Tensor:
+        return nn.functional.linear(
+            x, dequantize_mat(self.weight_q, self.max_val), bias=self.bias
+        )
+
+
+# act_max for smooth
+class W8X8Linear(nn.Module):
+    def __init__(
+        self,
+        ori_w: Tensor,
+        bias: Optional[Tensor] = None,
+        act_max: Optional[Tensor] = None,
+        alpha=32,
+    ):
+        super().__init__()
+        self.bias = None if bias is None else nn.Parameter(bias, requires_grad=False)
+        self.dtype = ori_w.dtype
+        self.alpha = alpha
+        self.scales = None
+        if act_max is not None:
+            self.scales = (
+                (act_max.pow(alpha) / ops.max(ori_w.abs(), dim=0)[0].pow(1 - alpha))
+                .clamp(min=1e-5)
+                .to(dtype=ori_w.dtype)
+            )
+            self.scales = nn.Parameter(self.scales, requires_grad=False)
+            ori_w = ori_w.mul(self.scales)
+        self.weight_q, self.max_val = quantize_mat(ori_w)
+        self.weight_q = nn.Parameter(self.weight_q.t(), requires_grad=False)
+        self.max_val = nn.Parameter(self.max_val, requires_grad=False)
+
+    def forward(self, x: Tensor) -> Tensor:
+        if self.scales is not None:
+            x = x.div(self.scales)
+        x_q, x_max = quantize_mat(x)
+        res = qMatmul(x_q, x_max, self.weight_q, self.max_val, x.dtype)
+        if self.bias is not None:
+            res = res + self.bias
+        return res
+
+
+# static decomposition
+class W8SDLinear(nn.Module):
+    def __init__(
+        self,
+        origin_weight: Tensor,
+        bias: Optional[Tensor] = None,
+        act_max: Optional[Tensor] = None,
+        alpha=32,
+    ):
+        super().__init__()
+        self.bias = None if bias is None else nn.Parameter(bias, requires_grad=False)
+        self.dtype = origin_weight.dtype
+        self.alpha = alpha
+        if act_max is not None:
+            self.idx_unq, self.t = get_unq_idx_topk(act_max, self.alpha)
+        else:
+            self.idx_unq, self.t = get_unq_idx_topk(origin_weight, self.alpha)
+
+        self.weight_q, self.weight_unq = decomposition(
+            origin_weight, self.idx_unq, self.t
+        )
+        self.weight_q, self.w_max = quantize_mat(self.weight_q)
+        self.weight_q = nn.Parameter(self.weight_q.t(), requires_grad=False)
+        self.weight_unq = nn.Parameter(self.weight_unq.t(), requires_grad=False)
+        self.w_max = nn.Parameter(self.w_max, requires_grad=False)
+        self.t = nn.Parameter(self.t, requires_grad=False)
+        self.idx_unq = nn.Parameter(self.idx_unq, requires_grad=False)
+
+    def forward(self, x: Tensor) -> Tensor:
+        x_q, x_unq = decomposition(x, self.idx_unq, self.t)
+        x_q, x_max = quantize_mat(x_q)
+        res_q = qMatmul(x_q, x_max, self.weight_q, self.w_max, x.dtype)
+        res_unq = ops.matmul(x_unq, self.weight_unq)
+        if self.bias is not None:
+            res_unq += self.bias
+        return res_q + res_unq
+
+
+class W8DXLinear(nn.Module):
+    def __init__(
+        self,
+        origin_weight: Tensor,
+        bias: Optional[Tensor] = None,
+        act_max: Optional[Tensor] = None,
+        alpha=32,
+    ):
+        super().__init__()
+        self.bias = None if bias is None else nn.Parameter(bias, requires_grad=False)
+        self.dtype = origin_weight.dtype
+        self.alpha = alpha
+        self.weight_q, self.max_val = quantize_mat(origin_weight)
+        self.weight_q = nn.Parameter(self.weight_q.t(), requires_grad=False)
+        self.max_val = nn.Parameter(self.max_val, requires_grad=False)
+
+    def forward(self, x: Tensor) -> Tensor:
+        idx_unq, t = get_unq_idx_topk(x, self.alpha)
+        x_q, x_unq = decomposition(x, idx_unq, t)
+        x_q, x_max = quantize_mat(x_q)
+        res_q = qMatmul(x_q, x_max, self.weight_q, self.max_val, x.dtype)
+        weight_unq = ops.mul(self.weight_q[idx_unq, :], self.max_val.unsqueeze(0))
+        res_unq = ops.matmul(x_unq, weight_unq)
+        if self.bias is not None:
+            res_unq += self.bias
+        return res_q + res_unq
+
+
+quant_cls = {"W8": W8Linear, "W8X8": W8X8Linear, "W8SD": W8SDLinear, "W8DX": W8DXLinear}
+
+
+def replace_linear_modules(module: nn.Module, prefix: str, act_scales, cfg):
+    for name, child in module.named_children():
+        fullname = (prefix + "." + name) if prefix != "" else name
+        if isinstance(child, nn.Linear):
+            strs = fullname.split(".")
+            # fullname: model.layers.21.self_attn.q_proj; layer_name: 21.q_proj; name: q_proj
+            # fullname: lm_head; layer_name: lm_head; name: lm_head
+            layer_name = (strs[-3] + "." + strs[-1]) if len(strs) > 2 else strs[-1]
+
+            if layer_name not in cfg:
+                continue
+            act_scale = (
+                None
+                if act_scales is None or "act_scale" not in cfg[layer_name]
+                else act_scales[fullname]
+            )
+            alpha = None if "alpha" not in cfg[layer_name] else cfg[layer_name]["alpha"]
+            setattr(
+                module,
+                name,
+                quant_cls[cfg[layer_name]["type"]](
+                    child.weight, child.bias, act_max=act_scale, alpha=alpha
+                ),
+            )
+        else:
+            replace_linear_modules(child, fullname, act_scales, cfg)
+
+
+def quantize(model: nn.Module, cfg={}):
+    act_scales = None
+    if "act_scales_path" in cfg:
+        act_scales = load(cfg["act_scales_path"])
+    if "smooth" in cfg:
+
+        alpha = 0.85 if "alpha" not in cfg else cfg["alpha"]
+        smooth_lm(model, act_scales, alpha)
+    replace_linear_modules(model, "", act_scales, cfg)
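# --- Illustrative sketch (not part of the patch) ----------------------------
# The int8 scheme behind quantize_mat()/dequantize_mat() above, written out in
# plain numpy: per-row absmax scaling into [-127, 127], then rescaling back.
# Random numbers; only the size of the round-trip error matters.

import numpy as np

w = np.random.randn(4, 8).astype(np.float32)
scale = np.abs(w).max(axis=-1, keepdims=True) / 127.0     # one scale per row
w_q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
w_hat = w_q.astype(np.float32) * scale                    # dequantize
print(np.abs(w - w_hat).max())                            # at most scale / 2 per row
# quantize_mat() skips the explicit round/clip (the int8 cast truncates), but the
# scale/rescale structure is the same, and qMatmul() applies the two scale
# vectors to the int8 matmul result instead of dequantizing the inputs first.
# -----------------------------------------------------------------------------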
diff --git a/mindnlp/quant/smooth_quant/smooth.py b/mindnlp/quant/smooth_quant/smooth.py
new file mode 100644
index 000000000..04d13cd46
--- /dev/null
+++ b/mindnlp/quant/smooth_quant/smooth.py
@@ -0,0 +1,53 @@
+'''
+code from https://github.com/mit-han-lab/smoothquant/
+'''
+from mindnlp.core import ops, nn, no_grad
+
+from mindnlp.transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRMSNorm
+
+@no_grad()
+def smooth_ln_fcs_llama_like(ln, fcs, act_scales, alpha=0.5):
+    if not isinstance(fcs, list):
+        fcs = [fcs]
+    assert isinstance(ln, (LlamaRMSNorm, nn.Linear))
+    for fc in fcs:
+        assert isinstance(fc, nn.Linear)
+        assert ln.weight.shape[0] == fc.in_features == act_scales.numel()
+    dtype = fcs[0].weight.dtype
+    act_scales = act_scales.to(dtype=dtype)
+    weight_scales = ops.cat(
+        [ops.max(fc.weight.abs(), dim=0, keepdim=True)[0] for fc in fcs], dim=0
+    )
+    weight_scales = ops.max(weight_scales, dim=0)[0].clamp(min=1e-5)
+    scales = (
+        (act_scales.pow(alpha) / weight_scales.pow(1 - alpha))
+        .clamp(min=1e-5)
+        .to(dtype)
+    )
+    if ln.weight.dim() == 2:
+        ln.weight = ln.weight.div(scales.unsqueeze(-1))
+    else:
+        ln.weight = ln.weight.div(scales)
+    for fc in fcs:
+        fc.weight = fc.weight.mul(scales.view(1, -1))
+
+
+@no_grad()
+def smooth_lm(model, scales, alpha=0.5):
+    for name, module in model.named_modules():
+        if isinstance(module, LlamaDecoderLayer):
+            attn_ln = module.input_layernorm  # attention forward norm
+            qkv = [
+                module.self_attn.q_proj,
+                module.self_attn.k_proj,
+                module.self_attn.v_proj,
+            ]
+
+            qkv_input_scales = scales[name + ".self_attn.q_proj"]
+            smooth_ln_fcs_llama_like(attn_ln, qkv, qkv_input_scales, alpha)
+
+            ffn_ln = module.post_attention_layernorm  # feed forward norm
+            fcs = [module.mlp.gate_proj, module.mlp.up_proj]
+            fcs_input_scales = scales[name + ".mlp.gate_proj"]
+
+            smooth_ln_fcs_llama_like(ffn_ln, fcs, fcs_input_scales, alpha)
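# --- Illustrative sketch (not part of the patch) ----------------------------
# The per-channel scale computed in smooth_ln_fcs_llama_like() above is
# s_j = act_max_j**alpha / weight_max_j**(1 - alpha); activations are divided
# by s (folded into the preceding RMSNorm weight) and the linear weights are
# multiplied by s, so X @ W is unchanged while activation outliers shrink.
# Toy numbers below, alpha = 0.5.

import numpy as np

alpha = 0.5
act_max = np.array([10.0, 1.0, 0.5])   # per-channel |X| maxima from calibration
w_max = np.array([0.2, 0.3, 0.4])      # per-channel |W| maxima (over the output dim)
s = act_max**alpha / w_max**(1 - alpha)
print(s)                               # the outlier channel 0 gets the largest scale
print(act_max / s, w_max * s)          # smoothed activation/weight ranges are balanced
# -----------------------------------------------------------------------------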
diff --git a/mindnlp/transformers/generation/logits_process.py b/mindnlp/transformers/generation/logits_process.py
index 3806a0281..004c3cc21 100644
--- a/mindnlp/transformers/generation/logits_process.py
+++ b/mindnlp/transformers/generation/logits_process.py
@@ -466,6 +466,10 @@ def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor) -> min
 
         if self.filter_value == -float("Inf"):
             self.filter_value = float(ops.finfo(scores.dtype).min)
+
+        if ON_ORANGE_PI:
+            return self.tf_like_call(input_ids, scores)
+
         sorted_logits, sorted_indices = ops.sort(scores, descending=False)
         cumulative_probs = ops.cumsum(ops.softmax(sorted_logits, dim=-1), dim=-1)
 
diff --git a/mindnlp/transformers/generation/stopping_criteria.py b/mindnlp/transformers/generation/stopping_criteria.py
index ea74a0c45..dac16fefe 100644
--- a/mindnlp/transformers/generation/stopping_criteria.py
+++ b/mindnlp/transformers/generation/stopping_criteria.py
@@ -25,6 +25,7 @@
 from mindnlp.core import ops
 from mindnlp.core.nn import functional as F
 
+from ..ms_utils import isin_friendly
 from ..tokenization_utils_base import PreTrainedTokenizerBase
 from ...utils import logging
 
@@ -470,7 +471,7 @@ def __init__(self, eos_token_id: Union[int, List[int], mindspore.Tensor]):
         self.eos_token_id = eos_token_id
 
     def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor, **kwargs) -> mindspore.Tensor:
-        is_done = ops.isin(input_ids[:, -1], self.eos_token_id)
+        is_done = isin_friendly(input_ids[:, -1], self.eos_token_id)
         return is_done
 
 
diff --git a/mindnlp/transformers/models/qwen2/modeling_qwen2.py b/mindnlp/transformers/models/qwen2/modeling_qwen2.py
index b26e0b040..66d6a1aaa 100644
--- a/mindnlp/transformers/models/qwen2/modeling_qwen2.py
+++ b/mindnlp/transformers/models/qwen2/modeling_qwen2.py
@@ -38,7 +38,7 @@
 )
 from ...modeling_utils import PreTrainedModel
 from ....utils import logging
-from ....configs import SUPPORT_VIEW, use_pyboost
+from ....configs import SUPPORT_VIEW, use_pyboost, ON_ORANGE_PI
 from .configuration_qwen2 import Qwen2Config
 
 
@@ -119,7 +119,7 @@ def __init__(self, hidden_size, eps=1e-6):
         self.variance_epsilon = eps
 
     def forward(self, hidden_states):
-        if not self.training and use_pyboost():
+        if not self.training and use_pyboost() and not ON_ORANGE_PI:
             return F.rms_norm(hidden_states, self.weight, self.variance_epsilon)
         input_dtype = hidden_states.dtype
         hidden_states = hidden_states.to(mindspore.float32)
@@ -910,7 +910,12 @@ def forward(
         else:
             sequence_lengths = -1
 
-        pooled_logits = logits[ops.arange(batch_size), sequence_lengths]
+        if ON_ORANGE_PI:
+            if isinstance(sequence_lengths, mindspore.Tensor):
+                sequence_lengths = sequence_lengths.to(mindspore.int32)
+            pooled_logits = ops.getitem(logits, (ops.arange(batch_size), sequence_lengths))
+        else:
+            pooled_logits = logits[ops.arange(batch_size), sequence_lengths]
 
         loss = None
         if labels is not None:
diff --git a/mindnlp/transformers/ms_utils.py b/mindnlp/transformers/ms_utils.py
index 8bc1f38bb..c1ec59360 100644
--- a/mindnlp/transformers/ms_utils.py
+++ b/mindnlp/transformers/ms_utils.py
@@ -243,3 +243,17 @@
     Reference: https://pytorch.org/docs/1.13/generated/torch.meshgrid.html
     """
     return ops.meshgrid(*tensors, indexing=indexing)
+
+def isin_friendly(elements: mindspore.Tensor, test_elements: mindspore.Tensor) -> mindspore.Tensor:
+    """
+    Same as `ops.isin` without flags, but friendly to backends that lack a native `isin` kernel.
+
+    Args:
+        elements (`mindspore.Tensor`): Input elements.
+        test_elements (`mindspore.Tensor`): The elements to check against.
+
+    Returns:
+        `mindspore.Tensor`: A boolean tensor of the same shape as `elements` that is True for `elements` in `test_elements`
+        and False otherwise.
+    """
+    return elements.tile((test_elements.shape[0], 1)).eq(test_elements.unsqueeze(1)).sum(0).bool().squeeze()
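# --- Illustrative sketch (not part of the patch) ----------------------------
# How the new isin_friendly() behaves, next to the ops.isin call it replaces in
# stopping_criteria.py. Shapes: elements (batch,), test_elements (num_eos,).
# The token ids below are arbitrary.

import mindspore
from mindnlp.core import ops
from mindnlp.transformers.ms_utils import isin_friendly

last_tokens = mindspore.Tensor([151645, 42, 2], mindspore.int64)   # e.g. input_ids[:, -1]
eos_ids = mindspore.Tensor([2, 151645], mindspore.int64)
print(isin_friendly(last_tokens, eos_ids))   # [True, False, True]
print(ops.isin(last_tokens, eos_ids))        # same result where ops.isin is supported
# -----------------------------------------------------------------------------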