diff --git a/examples/transformers/qwen3_vl/README.md b/examples/transformers/qwen3_vl/README.md
new file mode 100644
index 0000000000..9b8852b950
--- /dev/null
+++ b/examples/transformers/qwen3_vl/README.md
@@ -0,0 +1,83 @@
+# Qwen3-VL series
+
+## Introduction
+[Qwen3-VL](https://huggingface.co/papers/2502.13923) is a multimodal vision-language model series that covers both dense and MoE variants, as well as Instruct and Thinking versions. Building on its predecessors, Qwen3-VL delivers significant improvements in visual understanding while maintaining strong text-only capabilities. Key architectural advances include enhanced MRoPE with an interleaved layout for better spatial-temporal modeling, DeepStack integration to leverage multi-level features from the Vision Transformer (ViT), and improved video understanding via text-based time alignment, evolving from T-RoPE to text timestamp alignment for more precise temporal grounding. Together, these changes enable Qwen3-VL to achieve superior performance on complex multimodal tasks.
+
+## Get Started
+
+### Requirements:
+| MindSpore | Ascend driver  | Firmware       | CANN toolkit/kernel |
+|-----------|----------------|----------------|---------------------|
+| 2.6.0     | 24.1.RC3.b080  | 7.5.T11.0.B088 | 8.1.RC1             |
+
+### Installation:
+```bash
+git clone https://github.com/mindspore-lab/mindone.git -b hf-transformers-4.54
+cd mindone
+pip install -e .
+cd ..
+
+# install transformers from source because Qwen3-VL (transformers v4.57.dev.0) has not been released yet
+git clone https://github.com/huggingface/transformers.git
+cd transformers
+git reset --hard d0af4269ec260b9c4aeeda24c346a469e44799e1
+pip install -e .
+cd ..
+
+cd mindone/examples/transformers/qwen3_vl
+```
+
+## Quick Start
+
+Here are usage examples of Qwen3-VL-4B-Instruct (single card) and Qwen3-VL-30B-A3B-Instruct (two cards via `msrun`). You can use the following commands:
+
+```bash
+# for Qwen3-VL-4B-Instruct inference
+python generate_qwen3_vl.py \
+    --model_name "Qwen/Qwen3-VL-4B-Instruct" \
+    --image "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg" \
+    --prompt "Describe this image."
+```
+
+```bash
+# for Qwen3-VL-30B-A3B-Instruct inference
+msrun --worker_num=2 --local_worker_num=2 --master_port=8118 \
+ --log_dir=msrun_log --join=True --cluster_time_out=300 \
+ generate_qwen3_vl_moe.py \
+ --model_name "Qwen/Qwen3-VL-30B-A3B-Instruct" \
+ --image "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg" \
+    --prompt "Describe this image."
+```
+
+Image:
+
+![pipeline-cat-chonk](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg)
+
+Prompt: Describe this image.
+
+Qwen3-VL-4B Outputs:
+```
+['Of course, here is detailed description of the image provided.\n\n
+ This is a close-up photograph of a Pallas\'s cat ($Felis$, $manul$),
+an endangered wild feline species native to Central Aisa.
+...
+**Appearance:** It has a stocky and robust build with short legs
+and a large head relative to its body size. Its fur is thick and dense,
+appearing somewhat fluffy or "matted,", which is characteristic']
+```
+
+Qwen3-VL-30B Outputs:
+```
+['Of course, here is detailed description of the image provided.\n\n
+This is a dynamic and charming photograph of a Palla's cat (also known as a manul) in a snowy enviroment.
+...
+"Appearance:" The cat has a very distinctive apperance, characterized by its stocky, low-slung body and exceptionally
+thick, dense fur. This coat is a mix of brownish"]
+```
+
+`model_name` and `image` can be replaced with your local paths, as shown in the sketch below. Give it a try with various images and prompts 🤗🤗.
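+
+The snippet below is a minimal sketch that assumes the checkpoint and the test image have already been downloaded to hypothetical local paths; `--attn_implementation` accepts `flash_attention_2` (default) or `eager`, matching the argument parser in `generate_qwen3_vl.py`:
+
+```bash
+# hypothetical local paths for the checkpoint and the image
+python generate_qwen3_vl.py \
+    --model_name "./checkpoints/Qwen3-VL-4B-Instruct" \
+    --image "./assets/cat.jpeg" \
+    --prompt "Describe this image." \
+    --attn_implementation "flash_attention_2"
+```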
+
+## Inference Speed
+| model name | mindspore version | precision | cards | attention type | tokens/s |
+|:------------------------------:|:-----------------:|:----------:|:-----:|:--------------:|:----------:|
+| Qwen/Qwen3-VL-4B-Instruct | 2.6.0 | bf16 | 1 | flash_attn | 1.35 |
+| Qwen/Qwen3-VL-30B-A3B-Instruct | 2.6.0 | bf16 | 2 | flash_attn | 0.5 |
diff --git a/examples/transformers/qwen3_vl/generate_qwen3_vl.py b/examples/transformers/qwen3_vl/generate_qwen3_vl.py
new file mode 100644
index 0000000000..15b0a3c46d
--- /dev/null
+++ b/examples/transformers/qwen3_vl/generate_qwen3_vl.py
@@ -0,0 +1,79 @@
+import argparse
+
+import numpy as np
+
+import mindspore as ms
+
+from mindone.transformers import AutoProcessor, Qwen3VLForConditionalGeneration
+
+
+def generate(args):
+ model = Qwen3VLForConditionalGeneration.from_pretrained(
+ args.model_name,
+ mindspore_dtype=ms.bfloat16,
+ attn_implementation=args.attn_implementation,
+ )
+
+ processor = AutoProcessor.from_pretrained(
+ args.model_name,
+ use_fast=False,
+ )
+
+ messages = [
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "image",
+ "url": args.image,
+ },
+ {
+ "type": "text",
+ "text": args.prompt,
+ },
+ ],
+ }
+ ]
+
+ inputs = processor.apply_chat_template(
+ messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="np"
+ )
+
+    # convert processor outputs (numpy arrays / python lists) to MindSpore tensors
+ for key, value in inputs.items():
+ if isinstance(value, np.ndarray):
+ inputs[key] = ms.tensor(value)
+ elif isinstance(value, list):
+ inputs[key] = ms.Tensor(value)
+
+ generated_ids = model.generate(**inputs, max_new_tokens=128, do_sample=False)
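+    # keep only the newly generated tokens by stripping the prompt tokens from each sequence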
+ generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
+ output_text = processor.batch_decode(
+ generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+ )
+ print(output_text)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Qwen3VL demo.")
+
+ parser.add_argument("--prompt", type=str, default="Describe this image.")
+ parser.add_argument(
+ "--image",
+ type=str,
+ default="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg",
+ )
+ parser.add_argument(
+ "--model_name", type=str, default="Qwen/Qwen3-VL-4B-Instruct", help="Path to the pre-trained model."
+ )
+ parser.add_argument(
+ "--attn_implementation",
+ type=str,
+ default="flash_attention_2",
+ choices=["flash_attention_2", "eager"],
+ )
+
+ # Parse the arguments
+ args = parser.parse_args()
+
+ generate(args)
diff --git a/examples/transformers/qwen3_vl/generate_qwen3_vl_moe.py b/examples/transformers/qwen3_vl/generate_qwen3_vl_moe.py
new file mode 100644
index 0000000000..caba83c3ad
--- /dev/null
+++ b/examples/transformers/qwen3_vl/generate_qwen3_vl_moe.py
@@ -0,0 +1,91 @@
+import argparse
+from functools import partial
+
+import numpy as np
+
+import mindspore as ms
+import mindspore.mint.distributed as dist
+from mindspore.communication import GlobalComm
+
+from mindone.trainers.zero import prepare_network
+from mindone.transformers import AutoProcessor, Qwen3VLMoeForConditionalGeneration
+
+
+def generate(args):
+ model = Qwen3VLMoeForConditionalGeneration.from_pretrained(
+ args.model_name,
+ mindspore_dtype=ms.bfloat16,
+ attn_implementation=args.attn_implementation,
+ )
+
+ # use zero3 parallel
+ shard_fn = partial(prepare_network, zero_stage=3, optimizer_parallel_group=GlobalComm.WORLD_COMM_GROUP)
+ model = shard_fn(model)
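+    # after sharding, each rank holds only its slice of the parameters and gathers the
+    # full weights on the fly during the forward pass (see mindone/trainers/zero.py)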
+
+ processor = AutoProcessor.from_pretrained(
+ args.model_name,
+ use_fast=False,
+ )
+
+ messages = [
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "image",
+ "url": args.image,
+ },
+ {
+ "type": "text",
+ "text": args.prompt,
+ },
+ ],
+ }
+ ]
+
+ inputs = processor.apply_chat_template(
+ messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="np"
+ )
+
+    # convert processor outputs (numpy arrays / python lists) to MindSpore tensors
+ for key, value in inputs.items():
+ if isinstance(value, np.ndarray):
+ inputs[key] = ms.tensor(value)
+ elif isinstance(value, list):
+ inputs[key] = ms.Tensor(value)
+
+ generated_ids = model.generate(**inputs, max_new_tokens=128)
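+    # keep only the newly generated tokens by stripping the prompt tokens from each sequence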
+ generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
+ output_text = processor.batch_decode(
+ generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+ )
+ print(output_text)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Qwen3VLMoE demo.")
+
+ parser.add_argument("--prompt", type=str, default="Describe this image.")
+ parser.add_argument(
+ "--image",
+ type=str,
+ default="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg",
+ )
+ parser.add_argument(
+ "--model_name", type=str, default="Qwen/Qwen3-VL-30B-A3B-Instruct", help="Path to the pre-trained model."
+ )
+ parser.add_argument(
+ "--attn_implementation",
+ type=str,
+ default="flash_attention_2",
+ choices=["flash_attention_2", "eager"],
+ )
+
+ # Parse the arguments
+ args = parser.parse_args()
+
+ # set up card communication
+ dist.init_process_group(backend="hccl")
+ ms.set_auto_parallel_context(parallel_mode="data_parallel")
+
+ generate(args)
diff --git a/mindone/models/modules/parallel/__init__.py b/mindone/models/modules/parallel/__init__.py
index 5240aeb9c6..4fb91c0b72 100644
--- a/mindone/models/modules/parallel/__init__.py
+++ b/mindone/models/modules/parallel/__init__.py
@@ -2,6 +2,7 @@
from .conv import Conv1d, Conv2d, Conv3d, Mint_Conv2d, Mint_Conv3d
from .dense import Dense, Linear
+from .moe_text_experts import MoeTextExperts
# {Original MindSpore Cell: New Cell in ZeRO3}
PARALLEL_MODULES = {
@@ -14,4 +15,6 @@
mint.nn.Linear: Linear,
}
+SPECIAL_CASE_FOR_PARALLEL_MODULES = {nn.Cell: MoeTextExperts}
+
__all__ = ["Conv1d", "Conv2d", "Conv3d", "Mint_Conv2d", "Mint_Conv3d", "Dense", "Linear"]
diff --git a/mindone/models/modules/parallel/moe_text_experts.py b/mindone/models/modules/parallel/moe_text_experts.py
new file mode 100644
index 0000000000..2d053e7c05
--- /dev/null
+++ b/mindone/models/modules/parallel/moe_text_experts.py
@@ -0,0 +1,70 @@
+from typing import Literal, Optional
+
+from mindspore import Tensor
+from mindspore import dtype as mstype
+from mindspore import mint, nn
+from mindspore.communication import get_group_size, get_rank
+from mindspore.communication.management import GlobalComm
+from mindspore.context import ParallelMode
+from mindspore.parallel._utils import _get_parallel_mode
+
+from .param_wrapper import ZeroParamWrapper
+
+
+class MoeTextExperts(nn.Cell):
+ def __init__(
+ self,
+ net: nn.Cell,
+ zero_stage: Literal[0, 1, 2, 3] = 0,
+ optimizer_parallel_group: str = GlobalComm.WORLD_COMM_GROUP,
+ cell_type: Optional[mstype.Type] = None,
+ ):
+ super().__init__(auto_prefix=False)
+ self.net = net
+ self.set_param_wrapper(zero_stage, optimizer_parallel_group, cell_type)
+
+ def set_param_wrapper(self, zero_stage, optimizer_parallel_group, cell_type=None):
+ self.param_wrapper_gate_up_proj = nn.Identity()
+ self.param_wrapper_down_proj = nn.Identity()
+ if zero_stage == 3:
+ # Init parallel settings
+ is_parallel = _get_parallel_mode() == ParallelMode.DATA_PARALLEL
+ op_group_size = get_group_size(optimizer_parallel_group) if is_parallel else 1
+ op_rank_id = get_rank(optimizer_parallel_group) if is_parallel else 0
+ self.op_group_size = op_group_size
+ self.op_rank_id = op_rank_id
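+            # each rank keeps only a 1 / op_group_size slice of the expert weights below;
+            # the ZeroParamWrapper gathers the full tensors again when the cell runs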
+ self.param_wrapper_gate_up_proj = ZeroParamWrapper(
+ self.net.gate_up_proj, zero_stage, optimizer_parallel_group, cell_type
+ )
+ if self.param_wrapper_gate_up_proj.need_rewrite:
+ self.net.gate_up_proj.assign_value(
+ Tensor.from_numpy(
+ self.net.gate_up_proj.numpy().reshape(op_group_size, -1, *self.net.gate_up_proj.shape[1:])[
+ op_rank_id
+ ]
+ )
+ )
+ self.param_wrapper_down_proj = ZeroParamWrapper(
+ self.net.down_proj, zero_stage, optimizer_parallel_group, cell_type
+ )
+ if self.param_wrapper_down_proj.need_rewrite:
+ self.net.down_proj.assign_value(
+ Tensor.from_numpy(
+ self.net.down_proj.numpy().reshape(op_group_size, -1, *self.net.down_proj.shape[1:])[op_rank_id]
+ )
+ )
+
+    def construct(self, hidden_states, routing_weights, router_indices):
+        batch_size = hidden_states.shape[0]
+        hidden_states = hidden_states.reshape(-1, self.net.hidden_size)  # (num_tokens, hidden_size)
+
+        # expand the tokens to (num_experts, num_tokens, hidden_size) so every expert processes
+        # every token in a single batched matmul
+        hidden_states = hidden_states.repeat(self.net.num_experts, 1)
+        hidden_states = hidden_states.view(self.net.num_experts, -1, self.net.hidden_size)
+
+        # gated MLP per expert: gate_up_proj -> split into gate/up -> activation -> down_proj
+        gate_up = mint.bmm(hidden_states, self.param_wrapper_gate_up_proj(self.net.gate_up_proj))
+        gate, up = gate_up.chunk(2, dim=-1)
+        next_states = mint.bmm((up * self.net.act_fn(gate)), self.param_wrapper_down_proj(self.net.down_proj))
+        next_states = next_states.reshape(self.net.num_experts, batch_size, -1, self.net.hidden_size)
+        # weight every expert's output by its routing probability and sum over the experts
+        next_states = next_states * routing_weights.swapaxes(0, 1).view(self.net.num_experts, batch_size, -1)[..., None]
+        next_states = next_states.sum(dim=0)
+        return next_states
diff --git a/mindone/trainers/zero.py b/mindone/trainers/zero.py
index c0e459d69f..ecfd697c0b 100644
--- a/mindone/trainers/zero.py
+++ b/mindone/trainers/zero.py
@@ -10,7 +10,7 @@
from mindspore.context import ParallelMode
from mindspore.parallel._utils import _get_parallel_mode
-from mindone.models.modules.parallel import PARALLEL_MODULES
+from mindone.models.modules.parallel import PARALLEL_MODULES, SPECIAL_CASE_FOR_PARALLEL_MODULES
from .train_step import TrainOneStepWrapper
@@ -471,7 +471,7 @@ def get_cell_dtype(cell):
return None
-def _init_parallel_settings(net, optimizer_parallel_group, parallel_modules=None):
+def _init_parallel_settings(net, optimizer_parallel_group, parallel_modules=None, special_cases_parallel_module=None):
for module, parallel_module in parallel_modules.items():
if isinstance(net, module):
cell_type = get_cell_dtype(net)
@@ -479,6 +479,14 @@ def _init_parallel_settings(net, optimizer_parallel_group, parallel_modules=None
if cell_type is not None:
new_net.to_float(cell_type)
return new_net
+    for module, parallel_module in (special_cases_parallel_module or {}).items():
+ if net.trainable_params():
+ if "gate_up_proj" in net.trainable_params()[0].name:
+ cell_type = get_cell_dtype(net)
+ new_net = parallel_module(net, 3, optimizer_parallel_group)
+ if cell_type is not None:
+ new_net.to_float(cell_type)
+ return new_net
return None
@@ -489,14 +497,20 @@ def get_cell_params_fullname_dict(cell: nn.Cell):
return fullname_dict
-def _prepare_network(network: nn.Cell, optimizer_parallel_group: str, parallel_modules=None):
- new_net = _init_parallel_settings(network, optimizer_parallel_group, parallel_modules)
+def _prepare_network(
+ network: nn.Cell, optimizer_parallel_group: str, parallel_modules=None, special_cases_parallel_module=None
+):
+ new_net = _init_parallel_settings(
+ network, optimizer_parallel_group, parallel_modules, special_cases_parallel_module
+ )
if new_net is not None:
return new_net
for name, sub_net in network._cells.items():
if not sub_net:
continue
- new_sub_net = _init_parallel_settings(sub_net, optimizer_parallel_group, parallel_modules)
+ new_sub_net = _init_parallel_settings(
+ sub_net, optimizer_parallel_group, parallel_modules, special_cases_parallel_module
+ )
if new_sub_net is not None:
params_fullname_dict = get_cell_params_fullname_dict(sub_net)
if isinstance(network, (nn.CellList, nn.SequentialCell)):
@@ -515,18 +529,26 @@ def _prepare_network(network: nn.Cell, optimizer_parallel_group: str, parallel_m
param = getattr(sub_net, param_name)
_logger.warning(f"Set param {param.name} parallel_optimizer False, param shape {param.shape}")
param.parallel_optimizer = False
- _prepare_network(sub_net, optimizer_parallel_group, parallel_modules)
+ _prepare_network(sub_net, optimizer_parallel_group, parallel_modules, special_cases_parallel_module)
return network
-def prepare_network(network: nn.Cell, zero_stage: int = 0, optimizer_parallel_group: str = None, parallel_modules=None):
+def prepare_network(
+ network: nn.Cell,
+ zero_stage: int = 0,
+ optimizer_parallel_group: str = None,
+ parallel_modules=None,
+ special_cases_parallel_module=None,
+):
if zero_stage != 3 or _get_parallel_mode() != ParallelMode.DATA_PARALLEL:
_logger.info("No need rewrite network and return original network.")
return network
_logger.info("Rewrite the network, please wait...")
if parallel_modules is None:
parallel_modules = PARALLEL_MODULES
- network = _prepare_network(network, optimizer_parallel_group, parallel_modules)
+ if special_cases_parallel_module is None:
+ special_cases_parallel_module = SPECIAL_CASE_FOR_PARALLEL_MODULES
+ network = _prepare_network(network, optimizer_parallel_group, parallel_modules, special_cases_parallel_module)
return network
diff --git a/mindone/transformers/__init__.py b/mindone/transformers/__init__.py
index d5de06f70d..2cda1be9ad 100644
--- a/mindone/transformers/__init__.py
+++ b/mindone/transformers/__init__.py
@@ -21,15 +21,29 @@
# to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
# in the namespace without actually importing anything (and especially none of the backends).
-__version__ = "4.50.0"
+__version__ = "4.54.1"
import transformers
from packaging import version
+from .cache_utils import (
+ Cache,
+ DynamicCache,
+ EncoderDecoderCache,
+ HybridCache,
+ MambaCache,
+ OffloadedStaticCache,
+ SlidingWindowCache,
+ StaticCache,
+)
+from .feature_extraction_sequence_utils import SequenceFeatureExtractor
+
# Feature Extractor
from .feature_extraction_utils import BatchFeature, FeatureExtractionMixin
from .image_processing_base import ImageProcessingMixin
from .image_processing_utils import BaseImageProcessor
+from .image_processing_utils_fast import BaseImageProcessorFast
from .image_utils import ImageFeatureExtractionMixin
+from .masking_utils import AttentionMaskInterface
from .modeling_utils import MSPreTrainedModel, PreTrainedModel
from .models.albert import (
AlbertForMaskedLM,
@@ -1559,3 +1573,20 @@
Qwen2_5OmniToken2WavModel,
)
from .models.vjepa2 import VJEPA2ForVideoClassification, VJEPA2Model, VJEPA2PreTrainedModel
+
+if version.parse(transformers.__version__) >= version.parse("4.57.0"):
+ from .models.qwen3_vl import (
+ Qwen3VLForConditionalGeneration,
+ Qwen3VLModel,
+ Qwen3VLPreTrainedModel,
+ Qwen3VLProcessor,
+ Qwen3VLTextModel,
+ Qwen3VLVisionModel,
+ )
+ from .models.qwen3_vl_moe import (
+ Qwen3VLMoeForConditionalGeneration,
+ Qwen3VLMoeModel,
+ Qwen3VLMoePreTrainedModel,
+ Qwen3VLMoeTextModel,
+ Qwen3VLMoeVisionModel,
+ )
diff --git a/mindone/transformers/activations.py b/mindone/transformers/activations.py
index 293736a143..7976f9b618 100644
--- a/mindone/transformers/activations.py
+++ b/mindone/transformers/activations.py
@@ -18,32 +18,33 @@
import math
from collections import OrderedDict
-import mindspore as ms
-from mindspore import Tensor, mint, nn, ops
+from mindspore import Tensor, mint, nn
class PytorchGELUTanh(nn.Cell):
"""
A fast C implementation of the tanh approximation of the GeLU activation function. See
- https://arxiv.org/abs/1606.08415.
+ https://huggingface.co/papers/1606.08415.
This implementation is equivalent to NewGELU and FastGELU but much faster. However, it is not an exact numerical
match due to rounding errors.
"""
def construct(self, input: Tensor) -> Tensor:
- return ops.gelu(input, approximate="tanh")
+ return mint.nn.functional.gelu(input, approximate="tanh")
class NewGELUActivation(nn.Cell):
"""
Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
- the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
+ the Gaussian Error Linear Units paper: https://huggingface.co/papers/1606.08415
"""
def construct(self, input: Tensor) -> Tensor:
return (
- 0.5 * input * (1.0 + ops.tanh(ops.sqrt(Tensor(2.0 / math.pi)) * (input + 0.044715 * ops.pow(input, 3.0))))
+ 0.5
+ * input
+ * (1.0 + mint.tanh(mint.sqrt(Tensor(2.0 / math.pi)) * (input + 0.044715 * mint.pow(input, 3.0))))
).to(input.dtype)
@@ -51,8 +52,8 @@ class GELUActivation(nn.Cell):
"""
Original Implementation of the GELU activation function in Google BERT repo when initially created. For
information: OpenAI GPT's GELU is slightly different (and gives slightly different results): 0.5 * x * (1 +
- ops.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * ops.pow(x, 3)))) This is now written in C in nn.functional
- Also see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
+ mint.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * mint.pow(x, 3)))) This is now written in C in nn.functional
+ Also see the Gaussian Error Linear Units paper: https://huggingface.co/papers/1606.08415
"""
def __init__(self, use_gelu_python: bool = False):
@@ -60,10 +61,10 @@ def __init__(self, use_gelu_python: bool = False):
if use_gelu_python:
self.act = self._gelu_python
else:
- self.act = ops.gelu
+ self.act = mint.nn.functional.gelu
def _gelu_python(self, input: Tensor) -> Tensor:
- return input * 0.5 * (1.0 + ops.erf(input / math.sqrt(2.0)))
+ return input * 0.5 * (1.0 + mint.erf(input / math.sqrt(2.0)))
def construct(self, input: Tensor) -> Tensor:
return self.act(input)
@@ -75,7 +76,7 @@ class FastGELUActivation(nn.Cell):
"""
def construct(self, input: Tensor) -> Tensor:
- return 0.5 * input * (1.0 + ops.tanh(input * 0.7978845608 * (1.0 + 0.044715 * input * input)))
+ return 0.5 * input * (1.0 + mint.tanh(input * 0.7978845608 * (1.0 + 0.044715 * input * input)))
class QuickGELUActivation(nn.Cell):
@@ -83,25 +84,21 @@ class QuickGELUActivation(nn.Cell):
Applies GELU approximation that is fast but somewhat inaccurate. See: https://github.com/hendrycks/GELUs
"""
- def __init__(self):
- super(QuickGELUActivation, self).__init__()
- self.sigmoid = nn.Sigmoid()
-
- def construct(self, input):
- return input * self.sigmoid(1.702 * input)
+ def construct(self, input: Tensor) -> Tensor:
+ return input * mint.sigmoid(1.702 * input)
class ClippedGELUActivation(nn.Cell):
"""
Clip the range of possible GeLU outputs between [min, max]. This is especially useful for quantization purpose, as
it allows mapping negatives values in the GeLU spectrum. For more information on this trick, please refer to
- https://arxiv.org/abs/2004.09602.
+ https://huggingface.co/papers/2004.09602.
Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when
initially created.
For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 0.5 * x * (1 +
- ops.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * ops.pow(x, 3)))). See https://arxiv.org/abs/1606.08415
+    mint.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * mint.pow(x, 3)))). See https://huggingface.co/papers/1606.08415
"""
def __init__(self, min: float, max: float):
@@ -114,7 +111,7 @@ def __init__(self, min: float, max: float):
self.gelu = get_activation("gelu")
def construct(self, x: Tensor) -> Tensor:
- return ops.clip(self.gelu(x), self.min, self.max)
+ return mint.clip(self.gelu(x), self.min, self.max)
class AccurateGELUActivation(nn.Cell):
@@ -130,30 +127,17 @@ def __init__(self):
self.precomputed_constant = math.sqrt(2 / math.pi)
def construct(self, input: Tensor) -> Tensor:
- return 0.5 * input * (1 + ops.tanh(self.precomputed_constant * (input + 0.044715 * ops.pow(input, 3))))
-
-
-class SiLUActivationFP32(nn.Cell):
- def __init__(self):
- super(SiLUActivationFP32, self).__init__()
- self.sigmoid = nn.Sigmoid()
-
- def construct(self, x):
- _dtype = x.dtype
- x = x.to(ms.float32)
- out = x * self.sigmoid(x)
- out = out.to(_dtype)
- return out
+ return 0.5 * input * (1 + mint.tanh(self.precomputed_constant * (input + 0.044715 * mint.pow(input, 3))))
class MishActivation(nn.Cell):
"""
- See Mish: A Self-Regularized Non-Monotonic Activation Function (Misra., https://arxiv.org/abs/1908.08681). Also
+ See Mish: A Self-Regularized Non-Monotonic Activation Function (Misra., https://huggingface.co/papers/1908.08681). Also
visit the official repository for the paper: https://github.com/digantamisra98/Mish
"""
def construct(self, input: Tensor) -> Tensor:
- return ops.mish(input)
+ return mint.nn.functional.mish(input)
class LinearActivation(nn.Cell):
@@ -168,24 +152,24 @@ def construct(self, input: Tensor) -> Tensor:
class LaplaceActivation(nn.Cell):
"""
Applies elementwise activation based on Laplace function, introduced in MEGA as an attention activation. See
- https://arxiv.org/abs/2209.10655
+ https://huggingface.co/papers/2209.10655
Inspired by squared relu, but with bounded range and gradient for better stability
"""
def construct(self, input, mu=0.707107, sigma=0.282095):
input = (input - mu).div(sigma * math.sqrt(2.0))
- return 0.5 * (1.0 + ops.erf(input))
+ return 0.5 * (1.0 + mint.erf(input))
class ReLUSquaredActivation(nn.Cell):
"""
- Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2
+ Applies the relu^2 activation introduced in https://huggingface.co/papers/2109.08668v2
"""
def construct(self, input):
- relu_applied = ops.relu(input)
- squared = ops.square(relu_applied)
+ relu_applied = mint.nn.functional.relu(input)
+ squared = mint.square(relu_applied)
return squared
@@ -205,16 +189,18 @@ def __getitem__(self, key):
"gelu_pytorch_tanh": PytorchGELUTanh,
"gelu_accurate": AccurateGELUActivation,
"laplace": LaplaceActivation,
+ "leaky_relu": (nn.LeakyReLU, {"alpha": 0.01}),
"linear": LinearActivation,
"mish": MishActivation,
"quick_gelu": QuickGELUActivation,
- "relu": nn.ReLU,
+ "relu": mint.nn.ReLU,
"relu2": ReLUSquaredActivation,
- "relu6": nn.ReLU6,
- "sigmoid": nn.Sigmoid,
- "silu": SiLUActivationFP32,
- "swish": SiLUActivationFP32,
- "tanh": nn.Tanh,
+ "relu6": mint.nn.ReLU6,
+ "sigmoid": mint.nn.Sigmoid,
+ "silu": mint.nn.SiLU,
+ "swish": mint.nn.SiLU,
+ "tanh": mint.nn.Tanh,
+ "prelu": mint.nn.PReLU,
}
ACT2FN = ClassInstantier(ACT2CLS)
@@ -226,4 +212,12 @@ def get_activation(activation_string):
raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}")
+# For backwards compatibility with: from activations import gelu_python
+gelu_python = get_activation("gelu_python")
+gelu_new = get_activation("gelu_new")
gelu = get_activation("gelu")
+gelu_fast = get_activation("gelu_fast")
+quick_gelu = get_activation("quick_gelu")
+silu = get_activation("silu")
+mish = get_activation("mish")
+linear_act = get_activation("linear")
diff --git a/mindone/transformers/audio_utils.py b/mindone/transformers/audio_utils.py
index 6059e3ba42..60ca04c591 100644
--- a/mindone/transformers/audio_utils.py
+++ b/mindone/transformers/audio_utils.py
@@ -18,11 +18,180 @@
Audio processing functions to extract features from audio waveforms. This code is pure numpy to support all frameworks
and remove unnecessary dependencies.
"""
-
+import base64
+import io
+import os
import warnings
-from typing import List, Optional, Tuple, Union
+from io import BytesIO
+from typing import Any, List, Optional, Tuple, Union
import numpy as np
+import requests
+
+import mindspore as ms
+
+from .utils import is_librosa_available, is_mindspore_tensor, is_numpy_array, is_soundfile_available, requires_backends
+
+if is_soundfile_available():
+ import soundfile as sf
+
+if is_librosa_available():
+ import librosa
+
+ # TODO: @eustlb, we actually don't need librosa but soxr is installed with librosa
+ import soxr
+
+
+def load_audio(audio: Union[str, np.ndarray], sampling_rate=16000, timeout=None) -> np.ndarray:
+ """
+ Loads `audio` to an np.ndarray object.
+
+ Args:
+ audio (`str` or `np.ndarray`):
+ The audio to be loaded to the numpy array format.
+ sampling_rate (`int`, *optional*, defaults to 16000):
+ The sampling rate to be used when loading the audio. It should be same as the
+ sampling rate the model you will be using further was trained with.
+ timeout (`float`, *optional*):
+ The timeout value in seconds for the URL request.
+
+ Returns:
+ `np.ndarray`: A numpy array representing the audio.
+ """
+ requires_backends(load_audio, ["librosa"])
+
+ if isinstance(audio, str):
+ # Load audio from URL (e.g https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/translate_to_chinese.wav)
+ if audio.startswith("http://") or audio.startswith("https://"):
+ audio = librosa.load(BytesIO(requests.get(audio, timeout=timeout).content), sr=sampling_rate)[0]
+ elif os.path.isfile(audio):
+ audio = librosa.load(audio, sr=sampling_rate)[0]
+ elif isinstance(audio, np.ndarray):
+ audio = audio
+ else:
+ raise TypeError(
+ "Incorrect format used for `audio`. Should be an url linking to an audio, a local path, or numpy array."
+ )
+ return audio
+
+
+def load_audio_as(
+ audio: str,
+ return_format: str,
+ timeout: Optional[int] = None,
+ force_mono: bool = False,
+ sampling_rate: Optional[int] = None,
+) -> Union[str, dict[str, Any], io.BytesIO, None]:
+ """
+ Load audio from either a local file path or URL and return in specified format.
+
+ Args:
+ audio (`str`): Either a local file path or a URL to an audio file
+ return_format (`str`): Format to return the audio in:
+ - "base64": Base64 encoded string
+ - "dict": Dictionary with data and format
+ - "buffer": BytesIO object
+ timeout (`int`, *optional*): Timeout for URL requests in seconds
+ force_mono (`bool`): Whether to convert stereo audio to mono
+ sampling_rate (`int`, *optional*): If provided, the audio will be resampled to the specified sampling rate.
+
+ Returns:
+ `Union[str, Dict[str, Any], io.BytesIO, None]`:
+ - `str`: Base64 encoded audio data (if return_format="base64")
+ - `dict`: Dictionary with 'data' (base64 encoded audio data) and 'format' keys (if return_format="dict")
+ - `io.BytesIO`: BytesIO object containing audio data (if return_format="buffer")
+ """
+ # TODO: @eustlb, we actually don't need librosa but soxr is installed with librosa
+ requires_backends(load_audio_as, ["librosa"])
+
+ if return_format not in ["base64", "dict", "buffer"]:
+ raise ValueError(f"Invalid return_format: {return_format}. Must be 'base64', 'dict', or 'buffer'")
+
+ try:
+ # Load audio bytes from URL or file
+ audio_bytes = None
+ if audio.startswith(("http://", "https://")):
+ response = requests.get(audio, timeout=timeout)
+ response.raise_for_status()
+ audio_bytes = response.content
+ elif os.path.isfile(audio):
+ with open(audio, "rb") as audio_file:
+ audio_bytes = audio_file.read()
+ else:
+ raise ValueError(f"File not found: {audio}")
+
+ # Process audio data
+ with io.BytesIO(audio_bytes) as audio_file:
+ with sf.SoundFile(audio_file) as f:
+ audio_array = f.read(dtype="float32")
+ original_sr = f.samplerate
+ audio_format = f.format
+ if sampling_rate is not None and sampling_rate != original_sr:
+ # Resample audio to target sampling rate
+ audio_array = soxr.resample(audio_array, original_sr, sampling_rate, quality="HQ")
+ else:
+ sampling_rate = original_sr
+
+ # Convert to mono if needed
+ if force_mono and audio_array.ndim != 1:
+ audio_array = audio_array.mean(axis=1)
+
+ buffer = io.BytesIO()
+ sf.write(buffer, audio_array, sampling_rate, format=audio_format.upper())
+ buffer.seek(0)
+
+ if return_format == "buffer":
+ return buffer
+ elif return_format == "base64":
+ return base64.b64encode(buffer.read()).decode("utf-8")
+ elif return_format == "dict":
+ return {
+ "data": base64.b64encode(buffer.read()).decode("utf-8"),
+ "format": audio_format.lower(),
+ }
+
+ except Exception as e:
+ raise ValueError(f"Error loading audio: {e}")
+
+
+AudioInput = Union[
+ np.ndarray,
+ "ms.Tensor",
+ list[np.ndarray],
+ tuple[np.ndarray],
+ list["ms.Tensor"],
+ tuple["ms.Tensor"], # noqa: F821
+]
+
+
+def is_valid_audio(audio):
+ return is_numpy_array(audio) or is_mindspore_tensor(audio)
+
+
+def is_valid_list_of_audio(audio):
+ return audio and all(is_valid_audio(audio_i) for audio_i in audio)
+
+
+def make_list_of_audio(
+ audio: Union[list[AudioInput], AudioInput],
+) -> AudioInput:
+ """
+ Ensure that the output is a list of audio.
+ Args:
+ audio (`Union[list[AudioInput], AudioInput]`):
+ The input audio.
+ Returns:
+ list: A list of audio.
+ """
+ # If it's a list of audios, it's already in the right format
+ if isinstance(audio, (list, tuple)) and is_valid_list_of_audio(audio):
+ return audio
+
+ # If it's a single audio, convert it to a list of
+ if is_valid_audio(audio):
+ return [audio]
+
+ raise ValueError("Invalid input type. Must be a single audio or a list of audio")
def hertz_to_mel(freq: Union[float, np.ndarray], mel_scale: str = "htk") -> Union[float, np.ndarray]:
@@ -247,7 +416,8 @@ def mel_filter_bank(
Args:
num_frequency_bins (`int`):
- Number of frequencies used to compute the spectrogram (should be the same as in `stft`).
+ Number of frequency bins (should be the same as `n_fft // 2 + 1` where `n_fft` is the size of the Fourier\
+ Transform used to compute the spectrogram).
num_mel_filters (`int`):
Number of mel filters to generate.
min_frequency (`float`):
@@ -271,6 +441,12 @@ def mel_filter_bank(
if norm is not None and norm != "slaney":
raise ValueError('norm must be one of None or "slaney"')
+ if num_frequency_bins < 2:
+ raise ValueError(f"Require num_frequency_bins: {num_frequency_bins} >= 2")
+
+ if min_frequency > max_frequency:
+ raise ValueError(f"Require min_frequency: {min_frequency} <= max_frequency: {max_frequency}")
+
# center points of the triangular mel filters
mel_min = hertz_to_mel(min_frequency, mel_scale=mel_scale)
mel_max = hertz_to_mel(max_frequency, mel_scale=mel_scale)
@@ -279,7 +455,7 @@ def mel_filter_bank(
if triangularize_in_mel_space:
# frequencies of FFT bins in Hz, but filters triangularized in mel space
- fft_bin_width = sampling_rate / (num_frequency_bins * 2)
+ fft_bin_width = sampling_rate / ((num_frequency_bins - 1) * 2)
fft_freqs = hertz_to_mel(fft_bin_width * np.arange(num_frequency_bins), mel_scale=mel_scale)
filter_freqs = mel_freqs
else:
@@ -978,145 +1154,3 @@ def amplitude_to_db_batch(
spectrogram = np.clip(spectrogram, a_min=max_values - db_range, a_max=None)
return spectrogram
-
-
-def get_mel_filter_banks(
- nb_frequency_bins: int,
- nb_mel_filters: int,
- frequency_min: float,
- frequency_max: float,
- sample_rate: int,
- norm: Optional[str] = None,
- mel_scale: str = "htk",
-) -> np.array:
- warnings.warn(
- "The function `get_mel_filter_banks` is deprecated and will be removed in version 4.31.0 of Transformers",
- FutureWarning,
- )
- return mel_filter_bank(
- num_frequency_bins=nb_frequency_bins,
- num_mel_filters=nb_mel_filters,
- min_frequency=frequency_min,
- max_frequency=frequency_max,
- sampling_rate=sample_rate,
- norm=norm,
- mel_scale=mel_scale,
- )
-
-
-def fram_wave(waveform: np.array, hop_length: int = 160, fft_window_size: int = 400, center: bool = True):
- """
- In order to compute the short time fourier transform, the waveform needs to be split in overlapping windowed
- segments called `frames`.
-
- The window length (window_length) defines how much of the signal is contained in each frame, while the hop length
- defines the step between the beginning of each new frame.
-
-
- Args:
- waveform (`np.array` of shape `(sample_length,)`):
- The raw waveform which will be split into smaller chunks.
- hop_length (`int`, *optional*, defaults to 160):
- Step between each window of the waveform.
- fft_window_size (`int`, *optional*, defaults to 400):
- Defines the size of the window.
- center (`bool`, defaults to `True`):
- Whether or not to center each frame around the middle of the frame. Centering is done by reflecting the
- waveform on the left and on the right.
-
- Return:
- framed_waveform (`np.array` of shape `(waveform.shape // hop_length , fft_window_size)`):
- The framed waveforms that can be fed to `np.fft`.
- """
- warnings.warn(
- "The function `fram_wave` is deprecated and will be removed in version 4.31.0 of Transformers",
- FutureWarning,
- )
- frames = []
- for i in range(0, waveform.shape[0] + 1, hop_length):
- if center:
- half_window = (fft_window_size - 1) // 2 + 1
- start = i - half_window if i > half_window else 0
- end = i + half_window if i < waveform.shape[0] - half_window else waveform.shape[0]
- frame = waveform[start:end]
- if start == 0:
- padd_width = (-i + half_window, 0)
- frame = np.pad(frame, pad_width=padd_width, mode="reflect")
-
- elif end == waveform.shape[0]:
- padd_width = (0, (i - waveform.shape[0] + half_window))
- frame = np.pad(frame, pad_width=padd_width, mode="reflect")
-
- else:
- frame = waveform[i : i + fft_window_size]
- frame_width = frame.shape[0]
- if frame_width < waveform.shape[0]:
- frame = np.lib.pad(
- frame, pad_width=(0, fft_window_size - frame_width), mode="constant", constant_values=0
- )
- frames.append(frame)
-
- frames = np.stack(frames, 0)
- return frames
-
-
-def stft(frames: np.array, windowing_function: np.array, fft_window_size: int = None):
- """
- Calculates the complex Short-Time Fourier Transform (STFT) of the given framed signal. Should give the same results
- as `torch.stft`.
-
- Args:
- frames (`np.array` of dimension `(num_frames, fft_window_size)`):
- A framed audio signal obtained using `audio_utils.fram_wav`.
- windowing_function (`np.array` of dimension `(nb_frequency_bins, nb_mel_filters)`:
- A array representing the function that will be used to reduces the amplitude of the discontinuities at the
- boundaries of each frame when computing the STFT. Each frame will be multiplied by the windowing_function.
- For more information on the discontinuities, called *Spectral leakage*, refer to [this
- tutorial]https://download.ni.com/evaluation/pxi/Understanding%20FFTs%20and%20Windowing.pdf
- fft_window_size (`int`, *optional*):
- Size of the window om which the Fourier transform is applied. This controls the frequency resolution of the
- spectrogram. 400 means that the fourrier transform is computed on windows of 400 samples. The number of
- frequency bins (`nb_frequency_bins`) used to divide the window into equal strips is equal to
- `(1+fft_window_size)//2`. An increase of the fft_window_size slows the calculus time proportionnally.
-
- Example:
-
- ```python
- >>> from transformers.audio_utils import stft, fram_wave
- >>> import numpy as np
-
- >>> audio = np.random.rand(50)
- >>> fft_window_size = 10
- >>> hop_length = 2
- >>> framed_audio = fram_wave(audio, hop_length, fft_window_size)
- >>> spectrogram = stft(framed_audio, np.hanning(fft_window_size + 1))
- ```
-
- Returns:
- spectrogram (`np.ndarray`):
- A spectrogram of shape `(num_frames, nb_frequency_bins)` obtained using the STFT algorithm
- """
- warnings.warn(
- "The function `stft` is deprecated and will be removed in version 4.31.0 of Transformers",
- FutureWarning,
- )
- frame_size = frames.shape[1]
-
- if fft_window_size is None:
- fft_window_size = frame_size
-
- if fft_window_size < frame_size:
- raise ValueError("FFT size must greater or equal the frame size")
- # number of FFT bins to store
- nb_frequency_bins = (fft_window_size >> 1) + 1
-
- spectrogram = np.empty((len(frames), nb_frequency_bins), dtype=np.complex64)
- fft_signal = np.zeros(fft_window_size)
-
- for f, frame in enumerate(frames):
- if windowing_function is not None:
- np.multiply(frame, windowing_function, out=fft_signal[:frame_size])
- else:
- fft_signal[:frame_size] = frame
- spectrogram[f] = np.fft.fft(fft_signal, axis=0)[:nb_frequency_bins]
- return spectrogram.T
diff --git a/mindone/transformers/cache_utils.py b/mindone/transformers/cache_utils.py
index d24fda4457..e7ee8f09ab 100644
--- a/mindone/transformers/cache_utils.py
+++ b/mindone/transformers/cache_utils.py
@@ -4,16 +4,20 @@
Cache utils.
"""
import copy
+import functools
+import inspect
import json
import os
-from typing import Any, Dict, List, Optional, Tuple, Union
+from abc import ABC, abstractmethod
+from collections.abc import Iterable
+from typing import Any, Callable, Dict, Optional, Tuple, Union
import numpy as np
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
import mindspore as ms
-from mindspore import mint, nn, ops
+from mindspore import mint, ops
logger = logging.get_logger(__name__)
@@ -29,7 +33,6 @@ def init_static_cache(config: PretrainedConfig, max_batch_size: int, max_cache_l
head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
dtype = dtype if dtype is not None else ms.float32
-
if hasattr(config, "num_key_value_heads"):
num_key_value_heads = config.num_key_value_heads
else:
@@ -147,315 +150,752 @@ def reset(past_key_values):
return past_key_values
-class Cache(nn.Cell):
+class CacheLayerMixin(ABC):
+ """Base, abstract class for a single layer's cache."""
+
+ is_compileable = False
+
+ def __init__(self):
+ self.keys, self.values = None, None
+
+ @abstractmethod
+ def update(
+ self,
+ key_states: ms.Tensor,
+ value_states: ms.Tensor,
+ cache_kwargs: Optional[dict[str, Any]] = None,
+ ) -> tuple[ms.Tensor, ms.Tensor]:
+ ...
+
+ @abstractmethod
+ def get_seq_length(self, cache_position=None) -> int:
+ ...
+
+ @abstractmethod
+ def get_max_cache_shape(self) -> int:
+ ...
+
+ @abstractmethod
+ def get_mask_sizes(self, cache_position: ms.Tensor) -> tuple[int, int]:
+ ...
+
+ def reset(self) -> None:
+ """Resets the cache values while preserving the objects"""
+ self.keys.zero_()
+ self.values.zero_()
+
+ def reorder_cache(self, beam_idx: ms.Tensor) -> tuple[ms.Tensor, ms.Tensor]:
+ """Reorders this layer's cache for beam search."""
+ if self.keys.numel():
+ self.keys = self.keys.index_select(0, beam_idx)
+ if self.values.numel():
+ self.values = self.values.index_select(0, beam_idx)
+
+
+class DynamicLayer(CacheLayerMixin):
"""
- Base, abstract class for all caches. The actual data structure is specific to each subclass.
+ A cache layer that grows dynamically as more tokens are generated. This is the default for generative models.
+ It stores the Key and Value states as tensors with shape `[batch_size, num_heads, seq_len, head_dim]`.
+
+ See `CacheLayerMixin` for details on common methods that are implemented by all cache layers.
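+
+    Example (a minimal sketch; the `[batch_size, num_heads, seq_len, head_dim]` shapes below are illustrative):
+
+    ```python
+    >>> import mindspore as ms
+    >>> from mindspore import mint
+    >>> from mindone.transformers.cache_utils import DynamicLayer
+
+    >>> layer = DynamicLayer()
+    >>> key = mint.ones((1, 2, 4, 8), dtype=ms.float16)  # illustrative shapes
+    >>> value = mint.ones((1, 2, 4, 8), dtype=ms.float16)
+    >>> keys, values = layer.update(key, value)
+    >>> layer.get_seq_length()
+    4
+    ```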
"""
- is_compileable = False
+ is_sliding = False
def update(
self,
key_states: ms.Tensor,
value_states: ms.Tensor,
- layer_idx: int,
- cache_kwargs: Optional[Dict[str, Any]] = None,
- ) -> Tuple[ms.Tensor, ms.Tensor]:
+ cache_kwargs: Optional[dict[str, Any]] = None,
+ ) -> tuple[ms.Tensor, ms.Tensor]:
"""
- Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
+ Updates the cache with the new `key_states` and `value_states`.
Parameters:
key_states (`ms.Tensor`):
The new key states to cache.
value_states (`ms.Tensor`):
The new value states to cache.
- layer_idx (`int`):
- The index of the layer to cache the states for.
- cache_kwargs (`Dict[str, Any]`, `optional`):
- Additional arguments for the cache subclass. These are specific to each subclass and allow new types of
- cache to be created.
+ cache_kwargs (`dict[str, Any]`, *optional*):
+ Additional arguments for the cache subclass. No additional arguments are used in `DynamicLayer`.
Return:
A tuple containing the updated key and value states.
"""
- raise NotImplementedError("Make sure to implement `update` in a subclass.")
+ if self.keys is None:
+ self.keys = key_states
+ self.values = value_states
+ else:
+ self.keys = mint.cat([self.keys, key_states], dim=-2)
+ self.values = mint.cat([self.values, value_states], dim=-2)
+ return self.keys, self.values
- def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
- """Returns the sequence length of the cached states. A layer index can be optionally passed."""
- # TODO: deprecate this function in favor of `cache_position`
- raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.")
-
- def get_max_length(self) -> Optional[int]:
- """Returns the maximum sequence length of the cached states, if there is any."""
- raise NotImplementedError("Make sure to implement `get_max_length` in a subclass.")
-
- def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
- """Given the sequence length of the new inputs, returns the usable length of the cache."""
- # Cache without size limit -> all cache is usable
- # Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache
- # length, we will need to evict part of the cache (and thus not all cache is usable)
- max_length = self.get_max_length()
- previous_seq_length = self.get_seq_length(layer_idx)
- if max_length is not None and previous_seq_length + new_seq_length > max_length:
- return max_length - new_seq_length
- return previous_seq_length
+ def get_seq_length(self, cache_position=None) -> int:
+ """Returns the sequence length of the cached states."""
+ if self.keys is None or self.keys.numel() == 0:
+ return 0
+ return self.keys.shape[-2]
- def reorder_cache(self, beam_idx: ms.Tensor):
+ def get_max_cache_shape(self) -> int:
+ """Returns the maximum sequence length of the cache object. DynamicLayer does not have a maximum length."""
+ return -1
+
+ def reorder_cache(self, beam_idx: ms.Tensor) -> None:
"""Reorders the cache for beam search, given the selected beam indices."""
- for layer_idx in range(len(self.key_cache)):
- if self.key_cache[layer_idx] != []:
- self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx)
- if self.value_cache[layer_idx] != []:
- self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx)
+ if self.keys is not None and self.keys.numel():
+ self.keys = self.keys.index_select(0, beam_idx)
+ self.values = self.values.index_select(0, beam_idx)
- @property
- def seen_tokens(self):
- if hasattr(self, "_seen_tokens"):
- return self._seen_tokens
- else:
- return None
+ def crop(self, max_length: int) -> None:
+ """
+ Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be
+ negative to remove `max_length` tokens.
+ """
+ if max_length < 0:
+ max_length = self.get_seq_length() - abs(max_length)
+
+ if self.get_seq_length() <= max_length:
+ return
+ if self.keys is not None and self.keys.numel():
+ self.keys = self.keys[..., :max_length, :]
+ self.values = self.values[..., :max_length, :]
+
+ def batch_repeat_interleave(self, repeats: int) -> None:
+ """Repeat the cache `repeats` times in the batch dimension."""
+ if self.keys is not None and self.keys.numel():
+ self.keys = self.keys.repeat_interleave(repeats, dim=0)
+ self.values = self.values.repeat_interleave(repeats, dim=0)
+
+ def batch_select_indices(self, indices: ms.Tensor) -> None:
+ """Only keep the `indices` in the batch dimension of the cache."""
+ if self.keys is not None and self.keys.numel():
+ self.keys = self.keys[indices, ...]
+ self.values = self.values[indices, ...]
+
+ def get_mask_sizes(self, cache_position: ms.Tensor) -> tuple[int, int]:
+ """Return the length and offset of the cache, used to generate the mask"""
+ kv_offset = 0
+ query_length = cache_position.shape[0]
+ past_seen_tokens = self.get_seq_length()
+ kv_length = query_length + past_seen_tokens
+ return kv_length, kv_offset
-class StaticCache(Cache):
+ @classmethod
+ def from_tensors(cls, keys: ms.Tensor, values: ms.Tensor) -> "DynamicLayer":
+ """
+ Build a `DynamicLayer` instance from pre-existing key/value tensors.
+
+ Args:
+ keys (`ms.Tensor`):
+ Key cache tensor of shape ``[batch_size, num_heads, seq_len, head_dim]``.
+ values (`ms.Tensor`):
+ Value cache tensor of shape ``[batch_size, num_heads, seq_len, head_dim]``.
+
+ Returns:
+ `DynamicLayer`: The newly constructed layer whose internal cache directly references
+ the supplied tensors.
+ """
+ layer = cls()
+ layer.keys = keys
+ layer.values = values
+ return layer
+
+
+class StaticLayer(CacheLayerMixin):
"""
- Static Cache class to be used with `static shape`.
+ A static cache layer that stores the Key and Value states as static tensors with shape `[batch_size, num_heads, seq_len, head_dim]`.
+ It allocates its full backing tensors up-front and mutates them in-place. Built for `mindspore.jit` support.
- Parameters:
- config (`PretrainedConfig):
- The configuration file defining the shape-related attributes required to initialize the static cache.
- max_batch_size (`int`):
- The maximum batch size with which the model will be used.
- max_cache_len (`int`):
- The maximum sequence length with which the model will be used.
- dtype (*optional*, defaults to `ms.float32`):
- The default `dtype` to use when initializing the layer.
+ See `CacheLayerMixin` for details on common methods that are implemented by all cache layers.
"""
is_compileable = True
+ is_sliding = False
- def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, dtype=None) -> None:
- super().__init__()
- self.max_batch_size = max_batch_size
- self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
- # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
- self.head_dim = (
- config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
- )
+ def __init__(
+ self,
+ max_cache_len: int,
+ batch_size: int,
+ num_heads: int,
+ head_dim: int,
+ dtype: ms.Type = ms.float32,
+ sliding_window: Optional[int] = None,
+ ):
+ """
+ Args:
+ max_cache_len (`int`):
+ Maximum number of tokens that can be stored, used for tensor preallocation.
+ batch_size (`int`):
+ Maximum batch size the cache is pre-allocated for.
+ num_heads (`int`):
+ Number of attention heads.
+ head_dim (`int`):
+ Per-head hidden dimension.
+ dtype (`ms.Type`, defaults to `ms.float32`):
+ Data type of the cache tensors.
+
+ Notes:
+ Static layers allocate their full backing tensors up-front and mutate them
+ in-place. See the documentation of `Cache` for shared helper methods that
+ operate uniformly across all layer types.
+ """
+ self.max_cache_len = max_cache_len
+ self.max_batch_size = batch_size
+ self.num_heads = num_heads
+ self.head_dim = head_dim
+ self.dtype = dtype
- self.dtype = dtype if dtype is not None else ms.float32
- self.num_key_value_heads = (
- config.num_attention_heads
- if getattr(config, "num_key_value_heads", None) is None
- else config.num_key_value_heads
+ self.keys = mint.zeros(
+ (batch_size, num_heads, self.max_cache_len, head_dim),
+ dtype=dtype,
)
+ self.values = mint.zeros(
+ (batch_size, num_heads, self.max_cache_len, head_dim),
+ dtype=dtype,
+ )
+        # Note: in the PyTorch implementation, `torch._dynamo.mark_static_address` tags the cache
+        # tensors as fixed data pointers to avoid graph breaks; MindSpore has no equivalent, so
+        # nothing extra is done here.
- key_cache: List[ms.Parameter] = []
- value_cache: List[ms.Parameter] = []
- cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
- for _layer_index in range(config.num_hidden_layers):
- # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
- # breaks when updating the cache.
- new_layer_key_cache = ms.Parameter(
- ms.Tensor(np.zeros(cache_shape), dtype=self.dtype),
- name=f"key_cache_{_layer_index}",
- requires_grad=False,
- )
- new_layer_value_cache = ms.Parameter(
- ms.Tensor(np.zeros(cache_shape), dtype=self.dtype),
- name=f"value_cache_{_layer_index}",
- requires_grad=False,
- )
- key_cache.append(new_layer_key_cache)
- value_cache.append(new_layer_value_cache)
-
- self.key_cache = ms.ParameterTuple(key_cache)
- self.value_cache = ms.ParameterTuple(value_cache)
+ def get_max_cache_shape(self) -> int:
+ """Return the maximum cache shape of the cache"""
+ return self.max_cache_len
def update(
self,
key_states: ms.Tensor,
value_states: ms.Tensor,
- layer_idx: int,
- cache_kwargs: Optional[Dict[str, Any]] = None,
- ) -> Tuple[ms.Tensor, ms.Tensor]:
+ cache_kwargs: Optional[dict[str, Any]] = None,
+ ) -> tuple[ms.Tensor, ms.Tensor]:
"""
- Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
+ Update the static cache tensors in place.
- Parameters:
- key_states (`ms.Tensor`):
- The new key states to cache.
- value_states (`ms.Tensor`):
- The new value states to cache.
- layer_idx (`int`):
- The index of the layer to cache the states for.
- cache_kwargs (`Dict[str, Any]`, `optional`):
- Additional arguments for the cache subclass. The `StaticCache` needs the `cache_position` input
- to know how where to write in the cache.
+ Args:
+ key_states (`ms.Tensor`): The new key states to cache.
+ value_states (`ms.Tensor`): The new value states to cache.
+ cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache.
- Return:
- A tuple containing the updated key and value states.
+ Returns:
+ tuple[`ms.Tensor`, `ms.Tensor`]: The updated key and value states.
"""
- cache_position = cache_kwargs.get("cache_position")
- k_out = self.key_cache[layer_idx]
- v_out = self.value_cache[layer_idx]
+ cache_position = cache_kwargs.get("cache_position") if cache_kwargs else None
+ key_states = key_states.to(self.keys.dtype)
+ value_states = value_states.to(self.values.dtype)
if cache_position is None:
- k_out.copy_(key_states)
- v_out.copy_(value_states)
+ # Prefill phase where seq_len potentially equals max_cache_len. Directly copy.
+ self.keys.copy_(key_states)
+ self.values.copy_(value_states)
else:
- k_out[:, :, cache_position] = key_states
- v_out[:, :, cache_position] = value_states
-
- return k_out, v_out
-
- def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
- """Returns the sequence length of the cached states that were seen by the model."""
+ # Generation phase. Update specific positions.
+ # Use index_copy_ for in-place update (compile-friendly).
+ try:
+ self.keys.index_copy_(2, cache_position, key_states)
+ self.values.index_copy_(2, cache_position, value_states)
+            except Exception:
+                # Fallback for MindSpore builds where `Tensor.index_copy_` is not available.
+ self.keys[:, :, cache_position] = key_states
+ self.values[:, :, cache_position] = value_states
+ return self.keys, self.values
+
+ def get_seq_length(self, cache_position=None) -> int:
+ """Returns the sequence length of the cached states."""
+ if cache_position is not None:
+ return int(cache_position[-1] + 1)
# Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
# limit the check to the first batch member and head dimension.
- # TODO: deprecate this function in favor of `cache_position`
- return (self.key_cache[layer_idx][0, 0].any(axis=-1)).sum()
-
- def get_max_length(self) -> Optional[int]:
- # FIXME: deprecated function, should use get_max_cache_shape instead. Keep it for compatibility.
- """Returns the maximum sequence length of the cached states."""
- return self.max_cache_len
+ seq_length = (self.keys[0, 0].any(dim=-1)).sum() if self.keys is not None else 0
+ return seq_length
- def get_max_cache_shape(self) -> Optional[int]:
- return self.max_cache_len
+ def reorder_cache(self, beam_idx: ms.Tensor) -> None:
+ """Reorders the cache for beam search, given the selected beam indices."""
+ self.keys = self.keys.index_select(0, beam_idx)
+ self.values = self.values.index_select(0, beam_idx)
- def reset(self):
- """Resets the cache values while preserving the objects"""
- for layer_idx in range(len(self.key_cache)):
- # In-place ops prevent breaking the static address
- ops.assign(self.key_cache[layer_idx], ms.Tensor(0.0))
- ops.assign(self.value_cache[layer_idx], ms.Tensor(0.0))
+ def get_mask_sizes(self, cache_position: ms.Tensor) -> tuple[int, int]:
+ """Return the length and offset of the cache, used to generate the attention mask"""
+ kv_offset = 0
+ kv_length = self.max_cache_len
+ return kv_length, kv_offset
-class CacheConfig:
+class SlidingWindowLayer(StaticLayer):
"""
- Base class for cache configs
+ A static cache layer that implements sliding window attention caching.
+
+ See `CacheLayerMixin` for details on common methods that are implemented by all cache layers.
"""
- cache_implementation: None
+ is_sliding = True
- @classmethod
- def from_dict(cls, config_dict, **kwargs):
+ def __init__(self, sliding_window, *args, **kwargs):
"""
- Constructs a CacheConfig instance from a dictionary of parameters.
Args:
- config_dict (Dict[str, Any]): Dictionary containing configuration parameters.
- **kwargs: Additional keyword arguments to override dictionary values.
- Returns:
- CacheConfig: Instance of CacheConfig constructed from the dictionary.
+ sliding_window (`int`):
+ Effective window size: number of tokens that are kept on each update call.
"""
- config = cls(**config_dict)
- to_remove = []
- for key, value in kwargs.items():
- if hasattr(config, key):
- setattr(config, key, value)
- to_remove.append(key)
- for key in to_remove:
- kwargs.pop(key, None)
- return config
+ max_cache_len = kwargs.pop("max_cache_len", None)
+ max_cache_len = min(sliding_window, max_cache_len) if max_cache_len is not None else sliding_window
+        super().__init__(*args, max_cache_len=max_cache_len, **kwargs)
- # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_json_file
- def to_json_file(self, json_file_path: Union[str, os.PathLike]):
+ def update(
+ self,
+ key_states: ms.Tensor,
+ value_states: ms.Tensor,
+ cache_kwargs: Optional[dict[str, Any]] = None,
+ ) -> tuple[ms.Tensor, ms.Tensor]:
"""
- Save this instance to a JSON file.
+ Update the sliding window cache tensors in place.
Args:
- json_file_path (`str` or `os.PathLike`):
- Path to the JSON file in which this configuration instance's parameters will be saved.
- use_diff (`bool`, *optional*, defaults to `True`):
- If set to `True`, only the difference between the config instance and the default
- `QuantizationConfig()` is serialized to JSON file.
+ key_states (`ms.Tensor`): The new key states to cache.
+ value_states (`ms.Tensor`): The new value states to cache.
+ cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache.
+
+ Returns:
+ tuple[`ms.Tensor`, `ms.Tensor`]: The updated key and value states.
"""
- with open(json_file_path, "w", encoding="utf-8") as writer:
- config_dict = self.to_dict()
- json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
+ cache_position = cache_kwargs.get("cache_position") if cache_kwargs else None
+ if cache_position is None:
+ raise ValueError("`cache_position` must be provided for SlidingWindowLayer.")
- writer.write(json_string)
+ key_states = key_states.to(self.keys.dtype)
+ value_states = value_states.to(self.values.dtype)
- # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_dict
- def to_dict(self) -> Dict[str, Any]:
+ # Handle prefill phase when prompt length > sliding_window_size.
+ # Note that we store cropped key/value states in the cache but return the full key/value states.
+ if cache_position.shape[0] > self.max_cache_len:
+ new_k = key_states[:, :, -self.max_cache_len :, :]
+ new_v = value_states[:, :, -self.max_cache_len :, :]
+ self.keys.copy_(new_k)
+ self.values.copy_(new_v)
+ return key_states, value_states
+
+ # Sliding window logic for generation phase or prefill < window
+ slicing = mint.arange(self.max_cache_len)
+ current_seq_len = cache_position[-1] + 1 # Use last position to determine current length
+ to_shift = current_seq_len > self.max_cache_len
+ indices = (slicing + to_shift.sum()) % self.max_cache_len
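+ # e.g. with max_cache_len=4 and cache_position=[5] (decoding the 6th token), to_shift is True and
+ # indices = [1, 2, 3, 0], so the gather below rolls the ring buffer left by one slot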
+
+ k_out_shifted = self.keys[:, :, indices]
+ v_out_shifted = self.values[:, :, indices]
+
+ # Clamp cache_position to determine the *target index* within the shifted cache view
+ update_position = cache_position.clamp(min=0, max=self.max_cache_len - 1)
+
+ try:
+ k_out_updated = k_out_shifted.index_copy(2, update_position, key_states)
+ v_out_updated = v_out_shifted.index_copy(2, update_position, value_states)
+ except Exception: # some MindSpore backends do not support `index_copy`
+ # Fallback: clone the shifted views and write via slice assignment
+ k_out_updated = k_out_shifted.clone()
+ v_out_updated = v_out_shifted.clone()
+ k_out_updated[:, :, update_position] = key_states
+ v_out_updated[:, :, update_position] = value_states
+
+ self.keys.copy_(k_out_updated)
+ self.values.copy_(v_out_updated)
+ return self.keys, self.values
+
+ def get_mask_sizes(self, cache_position: ms.Tensor) -> tuple[int, int]:
+ """Return the length and offset of the cache, used to generate the attention mask"""
+ query_length = cache_position.shape[0]
+ first_cache_position = cache_position[0]
+
+ kv_offset = mint.clamp(first_cache_position - self.max_cache_len + 1, min=0)
+ # This is not general (see HybridChunkedCache for the whole general case), but it's what the cache returns
+ kv_length = max(query_length, self.max_cache_len)
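+ # e.g. max_cache_len=4 and cache_position=[6] (a single decoding step) gives kv_offset=3, kv_length=4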
+ return kv_length, kv_offset
+
+
+class ChunkedSlidingLayer(SlidingWindowLayer):
+ """
+ An extended SlidingWindowLayer that supports prefill chunking, originally implemented for Llama 4.
+
+ See `SlidingWindowLayer` for details on common methods that are implemented by all cache layers.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.cumulative_length = 0
+
+ def update(
+ self,
+ key_states: ms.Tensor,
+ value_states: ms.Tensor,
+ cache_kwargs: Optional[dict[str, Any]] = None,
+ ) -> tuple[ms.Tensor, ms.Tensor]:
+ cache_position = cache_kwargs.get("cache_position") if cache_kwargs else None
+ if cache_position is None:
+ raise ValueError("`cache_position` must be provided for ChunkedSlidingLayer.")
+
+ cumulative_length = self.cumulative_length
+ self.cumulative_length += key_states.shape[-2]
+ is_full = cumulative_length >= self.max_cache_len
+
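+ # Three cases below: the window is already full (roll it via concatenation and crop), the new chunk
+ # overflows the window (concatenate the valid prefix, then crop), or everything still fits
+ # (plain in-place index copy with an early return)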
+ if is_full:
+ full_key_states = mint.cat((self.keys[:, :, 1:, :], key_states), dim=-2)
+ full_value_states = mint.cat((self.values[:, :, 1:, :], value_states), dim=-2)
+ # Fast decoding path -> as the effective size is still the sliding window, it is extremely important
+ # to return `self.keys` and `self.values`, as they have a fixed address in memory
+ # (the values are the same as the full states, but not the address!)
+ if key_states.shape[-2] == 1:
+ self.keys.copy_(full_key_states)
+ self.values.copy_(full_value_states)
+ return self.keys, self.values
+ elif not is_full and cumulative_length + key_states.shape[2] > self.max_cache_len:
+ if cumulative_length == 0:
+ full_key_states = key_states
+ full_value_states = value_states
+ else:
+ full_key_states = mint.cat((self.keys[:, :, :cumulative_length, :], key_states), dim=-2)
+ full_value_states = mint.cat((self.values[:, :, :cumulative_length, :], value_states), dim=-2)
+ else:
+ try:
+ self.keys.index_copy_(2, cache_position, key_states)
+ self.values.index_copy_(2, cache_position, value_states)
+ except Exception: # MindSpore does not support index_copy_
+ self.keys[:, :, cache_position] = key_states
+ self.values[:, :, cache_position] = value_states
+ return self.keys, self.values
+
+ self.keys.copy_(full_key_states[:, :, -self.max_cache_len :, :])
+ self.values.copy_(full_value_states[:, :, -self.max_cache_len :, :])
+ return full_key_states, full_value_states
+
+ def reset(self) -> None:
+ super().reset()
+ self.cumulative_length = 0
+
+ def get_mask_sizes(self, cache_position: ms.Tensor) -> tuple[int, int]:
+ query_length = cache_position.shape[0]
+ first_cache_position = cache_position[0]
+ sliding_window = self.max_cache_len
+
+ kv_offset = mint.clamp(first_cache_position - sliding_window + 1, min=0)
+ # This is the true general case for any Cache using local attention (sliding or chunked)
+ if first_cache_position >= sliding_window:
+ # Here the Cache is already full
+ kv_length = sliding_window + query_length - 1
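+ # e.g. sliding_window=4, first_cache_position=6, query_length=2 gives kv_length=5, kv_offset=3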
+ elif first_cache_position < sliding_window and first_cache_position + query_length > sliding_window:
+ # Here the Cache becomes full with the new input
+ kv_length = first_cache_position + query_length
+ else:
+ # Here the Cache is still smaller than the local size, but we return the local size as it's static
+ kv_length = sliding_window
+ return kv_length, kv_offset
+
+
+class CacheProcessor:
+ """
+ Base class for cache processors. It defines `pre_update` and `post_update` hooks that are called before and after the cache update.
+ This class should be subclassed.
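+
+ Example (a minimal sketch; `ClampCacheProcessor` is a hypothetical subclass, shown only to illustrate the
+ `pre_update`/`post_update` hooks, and is not part of this module):
+
+ ```python
+ >>> class ClampCacheProcessor(CacheProcessor):
+ ...     def __init__(self, cache, **kwargs):
+ ...         pass  # no extra state needed for this sketch
+ ...
+ ...     def pre_update(self, cache, key_states, value_states, layer_idx, cache_kwargs=None):
+ ...         # clamp the incoming states before they are written into the cache
+ ...         return key_states.clamp(-10, 10), value_states.clamp(-10, 10)
+ ```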
+ """
+
+ def __init__(self, cache: "Cache", **kwargs) -> None:
"""
- Serializes this instance to a Python dictionary. Returns:
- `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
+ Initialize the processor and perform compatibility checks with the cache.
+
+ Args:
+ cache (`Cache`): The cache instance this processor will be applied to.
+ **kwargs: Additional arguments that may be needed for initialization.
"""
- return copy.deepcopy(self.__dict__)
+ raise NotImplementedError(f"Make sure to implement `init` in {self.__class__.__name__}.")
- # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__iter__
- def __iter__(self):
- """allows `dict(obj)` for situations where obj may be a dict or QuantizationConfigMixin"""
- for attr, value in copy.deepcopy(self.__dict__).items():
- yield attr, value
+ def pre_update(
+ self,
+ cache: "Cache",
+ key_states: ms.Tensor,
+ value_states: ms.Tensor,
+ layer_idx: int,
+ cache_kwargs: Optional[dict[str, Any]] = None,
+ ) -> tuple[ms.Tensor, ms.Tensor]:
+ """
+ Function called before the cache update. Can modify the key/value states.
- # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__repr__
- def __repr__(self):
- return f"{self.__class__.__name__} {self.to_json_string()}"
+ Args:
+ cache (`Cache`): The cache instance.
+ key_states (`ms.Tensor`): The new key states to cache.
+ value_states (`ms.Tensor`): The new value states to cache.
+ layer_idx (`int`): The index of the layer to cache the states for.
+ cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache.
- def to_json_string(self):
- """
- Serializes this instance to a JSON formatted string.
Returns:
- str: JSON formatted string representing the configuration instance.
+ The modified key and value states.
"""
- return json.dumps(self.__dict__, indent=2) + "\n"
+ return key_states, value_states
- # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.update
- def update(self, **kwargs):
+ def post_update(
+ self,
+ cache: "Cache",
+ key_tensors: ms.Tensor,
+ value_tensors: ms.Tensor,
+ layer_idx: int,
+ cache_kwargs: Optional[dict[str, Any]] = None,
+ ) -> tuple[ms.Tensor, ms.Tensor]:
"""
- Updates attributes of this class instance with attributes from `kwargs` if they match existing attributes,
- returning all the unused kwargs.
+ Function called after the cache update. Can process the cached data.
Args:
- kwargs (`Dict[str, Any]`):
- Dictionary of attributes to tentatively update this class.
+ cache (`Cache`): The cache instance.
+ key_tensors (`ms.Tensor`): The key states that were cached.
+ value_tensors (`ms.Tensor`): The value states that were cached.
+ layer_idx (`int`): The index of the layer that was updated.
+ cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache.
Returns:
- `Dict[str, Any]`: Dictionary containing all the key-value pairs that were not used to update the instance.
+ The final key and value states to return to the model.
"""
- to_remove = []
- for key, value in kwargs.items():
- if hasattr(self, key):
- setattr(self, key, value)
- to_remove.append(key)
+ return key_tensors, value_tensors
+
+
- # Remove all the attributes that were updated, without modifying the input dict
- unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove}
- return unused_kwargs
+class OffloadedCacheProcessor(CacheProcessor):
+ """
+ A cache processor that offloads cache tensors to conserve accelerator memory.
-class DynamicCache(Cache):
+ This processor manages moving cache tensors between accelerator and CPU memory,
+ using asynchronous prefetching to minimize performance impact. Works with both
+ dynamic and static layers.
"""
- A cache that grows dynamically as more tokens are generated. This is the default for generative models.
- It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
- `[batch_size, num_heads, seq_len, head_dim]`.
+ def __init__(self, cache: "Cache", **kwargs):
+ raise NotImplementedError
+
+
+class QuantizedCacheProcessor(CacheProcessor):
+ """
+ A cache processor that applies quantization to cache tensors to reduce memory usage.
+
+ This processor quantizes cache tensors after they are stored, maintaining a residual
+ length in original precision and quantizing older tokens.
+ """
+
+ def __init__(
+ self,
+ cache: "Cache",
+ backend: str = "quanto",
+ nbits: int = 4,
+ axis_key: int = 0,
+ axis_value: int = 0,
+ q_group_size: int = 64,
+ residual_length: int = 128,
+ compute_dtype: ms.Type = ms.float16,
+ ):
+ """
+ Parameters:
+ backend (`str`, defaults to `"quanto"`):
+ Backend to use when performing quantization. Can be one of [`quanto`, `HQQ`].
+ nbits (`int`, defaults to 4):
+ Number of bits, can be 2 or 4 for the `quanto` backend and one of [1, 2, 3, 4, 8] for the `HQQ` backend.
+ axis_key (`int`, defaults to 0):
+ Axis over which to perform grouping for the key tensors. Can be [0, -1] for `quanto` backend and [0, 1] for `HQQ` backend.
+ axis_value (`int`, defaults to 0):
+ Axis over which to perform grouping for the value tensors. Can be [0, -1] for `quanto` backend and [0, 1] for `HQQ` backend.
+ q_group_size (`int`, defaults to 64):
+ Size of the quantization group, should be a divisor of the model's hidden dimension.
+ Defaults to 64.
+ residual_length (`int`, defaults to 128):
+ Length of the residual cache which will always be stored in original precision.
+ Defaults to 128.
+ compute_dtype (`ms.Type`, defaults to `ms.float16`):
+ The default dtype used for computations in the model. Keys and Values will be cast to this dtype after dequantization.
+ """
+ raise NotImplementedError
+
+
+class QuantoQuantizedCacheProcessor(QuantizedCacheProcessor):
+ """
+ Quantized cache processor that uses `quanto` as a backend to perform quantization.
+ Current implementation supports `int2` and `int4` dtypes only.
"""
- def __init__(self, num_hidden_layers: Optional[int] = None) -> None:
- # in hf transformers there is no `num_hidden_layers` but `_distributed_cache_data`
- # it was originally added for compatibility with `torch.distributed` (DDP). See #36121
- # in mindspore there is no DDP, so we keep `num_hidden_layers`
- super().__init__()
- if num_hidden_layers is None:
- self.key_cache: List[ms.Tensor] = []
- self.value_cache: List[ms.Tensor] = []
+ def __init__(
+ self,
+ cache: "Cache",
+ backend: str = "quanto",
+ nbits: int = 4,
+ axis_key: int = 0,
+ axis_value: int = 0,
+ q_group_size: int = 64,
+ residual_length: int = 128,
+ compute_dtype: ms.Type = ms.float16,
+ ) -> None:
+ """Initialize the quanto quantization processor."""
+ super().__init__(cache, backend, nbits, axis_key, axis_value, q_group_size, residual_length, compute_dtype)
+
+ raise NotImplementedError
+
+
+class HQQQuantizedCacheProcessor(QuantizedCacheProcessor):
+ """
+ Quantized cache processor that uses `HQQ` as a backend to perform quantization.
+ Current implementation supports `int2`, `int4`, `int8` dtypes.
+ """
+
+ def __init__(
+ self,
+ cache: "Cache",
+ backend: str = "quanto",
+ nbits: int = 4,
+ axis_key: int = 0,
+ axis_value: int = 0,
+ q_group_size: int = 64,
+ residual_length: int = 128,
+ compute_dtype: ms.Type = ms.float16,
+ ) -> None:
+ """Initialize the HQQ quantization processor."""
+ super().__init__(cache, backend, nbits, axis_key, axis_value, q_group_size, residual_length, compute_dtype)
+ raise NotImplementedError
+
+
+def apply_processors(
+ fn: Callable[..., tuple[ms.Tensor, ms.Tensor]],
+) -> Callable[..., tuple[ms.Tensor, ms.Tensor]]:
+ @functools.wraps(fn)
+ def _wrapped_update(
+ self,
+ key_states: ms.Tensor,
+ value_states: ms.Tensor,
+ layer_idx: int,
+ cache_kwargs: Optional[dict[str, Any]] = None,
+ ) -> tuple[ms.Tensor, ms.Tensor]:
+ """
+ Wrapper around the update method to apply cache processors.
+ """
+ if self.cache_processor is not None:
+ key_states, value_states = self.cache_processor.pre_update(
+ self, key_states, value_states, layer_idx, cache_kwargs
+ )
+
+ key_tensors, value_tensors = fn(self, key_states, value_states, layer_idx, cache_kwargs)
+
+ if self.cache_processor is not None:
+ key_tensors, value_tensors = self.cache_processor.post_update(
+ self, key_tensors, value_tensors, layer_idx, cache_kwargs
+ )
+
+ return key_tensors, value_tensors
+
+ return _wrapped_update
+
+
+class KeyValuesWrapper:
+ """Helper class for Cache that simulates layer-indexed key/value lists from a layered cache.
+ This allows for BC access and writing, e.g., cache.key_cache[idx] = ...
+ Deprecated in favor of Cache.layers[idx].keys/values. TODO: remove in v4.56.0"""
+
+ def __init__(self, layers, cache_type="keys"):
+ self.layers = layers
+ self.cache_type = cache_type
+
+ def __getitem__(self, idx):
+ if isinstance(idx, slice):
+ return [getattr(layer, self.cache_type) for layer in self.layers[idx]]
+ return getattr(self.layers[idx], self.cache_type)
+
+ def __setitem__(self, idx, value):
+ if isinstance(idx, slice):
+ for layer, val in zip(self.layers[idx], value):
+ setattr(layer, self.cache_type, val)
else:
- self.key_cache: List[ms.Tensor] = [[] for _ in range(num_hidden_layers)]
- self.value_cache: List[ms.Tensor] = [[] for _ in range(num_hidden_layers)]
- self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen
+ setattr(self.layers[idx], self.cache_type, value)
+
+ def __len__(self):
+ return len(self.layers)
+
+ def __iter__(self):
+ for layer in self.layers:
+ yield getattr(layer, self.cache_type)
+
+ def __bool__(self):
+ return bool(self.layers)
+
- def __getitem__(self, layer_idx: int) -> List[Tuple[ms.Tensor]]:
+class Cache:
+ """
+ Base container for per-layer key/value caches.
+
+ A `Cache` behaves like a list of `CacheLayerMixin` objects, one per model layer.
+ Sub-classes such as `DynamicCache`, `StaticCache`, or `SlidingWindowCache`
+ simply pre-select which `CacheLayerMixin` class to use and may attach a
+ `CacheProcessor` (off-loading, quantization).
+
+ Example:
+
+ ```python
+ import mindspore as ms
+ from transformers import AutoTokenizer
+
+ from mindone.transformers import AutoModelForCausalLM, DynamicCache
+
+ model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
+ tok = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
+ inputs = tok("Hello", return_tensors="np")
+ for key in inputs.keys():
+     inputs[key] = ms.tensor(inputs[key])
+
+ cache = DynamicCache()
+ outputs = model(**inputs, past_key_values=cache, use_cache=True)
+ ```
+
+ Parameters:
+ layer_classes (`type[CacheLayerMixin]` or `list[type[CacheLayerMixin]]`):
+ A list of `CacheLayerMixin` classes to instantiate for the cache. If a single `CacheLayerMixin` class is
+ provided, it is used for all layers.
+ config (`PretrainedConfig`, *optional*):
+ Model configuration used to infer number of layers, head sizes, default
+ device/dtype, etc.
+ cache_processor (`CacheProcessor` or `str`, *optional*):
+ Cache processor to apply (e.g., "offloaded", "quanto_quantized", "hqq_quantized")
+ or a CacheProcessor class.
+ max_batch_size (`int`, *optional*): Maximum batch size for static caches.
+ max_cache_len (`int`, *optional*): Maximum sequence length. For hybrid caches, SlidingWindowLayers are
+ clamped to `min(sliding_window, max_cache_len)`, StaticLayers use full `max_cache_len`.
+ dtype (`ms.Type`, *optional*): Data type for cache tensors.
+ tp_size (`int`, *optional*): Tensor parallel size to adjust the number of key/value heads.
+
+ Additional keyword arguments are forwarded to the chosen layers constructor(s) and CacheProcessors. See the
+ documentation of the relevant `CacheLayerMixin` class and `CacheProcessor` class for more details.
+ """
+
+ def __init__(
+ self,
+ layer_classes: Union[list[type[CacheLayerMixin]], type[CacheLayerMixin]],
+ config: Optional[PretrainedConfig] = None,
+ cache_processor: Optional[Union[str, type[CacheProcessor]]] = None,
+ max_batch_size: Optional[int] = None,
+ max_cache_len: Optional[int] = None,
+ dtype: Optional[ms.Type] = None,
+ tp_size: Optional[int] = None,
+ **kwargs,
+ ):
+ self.layers: list[CacheLayerMixin] = []
+ self.layer_classes = layer_classes
+
+ processor_class = PROCESSOR_CLASS_MAP[cache_processor] if isinstance(cache_processor, str) else cache_processor
+ kwargs.update(
+ max_batch_size=max_batch_size,
+ max_cache_len=max_cache_len,
+ dtype=dtype,
+ tp_size=tp_size,
+ )
+ processor_kwargs, kwargs = parse_processor_args(processor_class, kwargs)
+
+ self.layer_init_kwargs = parse_layer_args_from_model_config(config, **kwargs)
+ self.num_hidden_layers = getattr(config, "num_hidden_layers", 1)
+
+ self.append_new_layers(self.num_hidden_layers - 1)
+ self.cache_processor = processor_class(self, **processor_kwargs) if processor_class is not None else None
+
+ def __getitem__(self, layer_idx: int) -> tuple[ms.Tensor, ms.Tensor]:
"""
Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
sequence length.
"""
- if layer_idx < len(self):
- return (self.key_cache[layer_idx], self.value_cache[layer_idx])
+ if layer_idx < len(self.layers):
+ return self.layers[layer_idx].keys, self.layers[layer_idx].values
else:
- raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
+ raise KeyError(
+ f"Cache only has {len(self.layers)} layers, attempted to access layer with index {layer_idx}"
+ )
def __iter__(self):
"""
@@ -463,22 +903,58 @@ def __iter__(self):
keys and values
"""
for layer_idx in range(len(self)):
- yield (self.key_cache[layer_idx], self.value_cache[layer_idx])
+ yield (self.layers[layer_idx].keys, self.layers[layer_idx].values)
def __len__(self):
"""
Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
to the number of layers in the model.
"""
- return len(self.key_cache)
+ # Best effort BC support for old-style caches like Mambas, Falcon, HybridChunked that rely on __len__
+ if getattr(self, "layers", None) is None:
+ if getattr(self, "key_cache", None) is not None:
+ return len(self.key_cache)
+ return 0
+ # Empty dynamic caches initialize an empty layer to be ready for first update
+ dynamic_empty = (
+ getattr(self, "layers", None) is not None
+ and len(self.layers) == 1
+ and isinstance(self.layers[0], DynamicLayer)
+ and self.layers[0].keys is None
+ )
+ return len(self.layers) if not dynamic_empty else 0
+
+ def __repr__(self):
+ return f"{self.__class__.__name__}(layers={self.layers})"
+
+ def append_new_layers(self, layer_idx: int) -> None:
+ """
+ Appends layers to the cache until the layer `layer_idx` is reached.
+ Used for preallocation in static caches and on the fly in dynamic caches.
+
+ Args:
+ layer_idx (`int`):
+ The index of the layer to append.
+ """
+ while len(self.layers) <= layer_idx:
+ kwargs = self.layer_init_kwargs.copy()
+ if self.layer_init_kwargs.get("layer_device_map", None) is not None:
+ kwargs["device"] = kwargs.pop("layer_device_map")[layer_idx]
+
+ new_layer_class = (
+ self.layer_classes[len(self.layers)] if isinstance(self.layer_classes, list) else self.layer_classes
+ )
+ new_layer = new_layer_class(**kwargs)
+ self.layers.append(new_layer)
+
+ @apply_processors
def update(
self,
key_states: ms.Tensor,
value_states: ms.Tensor,
layer_idx: int,
- cache_kwargs: Optional[Dict[str, Any]] = None,
- ) -> Tuple[ms.Tensor, ms.Tensor]:
+ cache_kwargs: Optional[dict[str, Any]] = None,
+ ) -> tuple[ms.Tensor, ms.Tensor]:
"""
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
@@ -489,58 +965,164 @@ def update(
The new value states to cache.
layer_idx (`int`):
The index of the layer to cache the states for.
- cache_kwargs (`Dict[str, Any]`, `optional`):
- Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.
+ cache_kwargs (`dict[str, Any]`, *optional*):
+ Additional arguments for the cache subclass. These are specific to each subclass and allow new types of
+ cache to be created.
Return:
A tuple containing the updated key and value states.
"""
- # Update the number of seen tokens
- if layer_idx == 0:
- self._seen_tokens += key_states.shape[-2]
-
- # Update the cache
- if len(self.key_cache) <= layer_idx:
- # There may be skipped layers, fill them with empty lists
- for _ in range(len(self.key_cache), layer_idx):
- self.key_cache.append([])
- self.value_cache.append([])
- self.key_cache.append(key_states)
- self.value_cache.append(value_states)
- # content on layer cache can be a tensor and checking not tensor causes errors
- # so we explicitly check for the empty list
- elif len(self.key_cache[layer_idx]) == 0:
- self.key_cache[layer_idx] = key_states
- self.value_cache[layer_idx] = value_states
- else:
- self.key_cache[layer_idx] = ops.cat([self.key_cache[layer_idx], key_states], axis=-2)
- self.value_cache[layer_idx] = ops.cat([self.value_cache[layer_idx], value_states], axis=-2)
-
- return self.key_cache[layer_idx], self.value_cache[layer_idx]
+ self.append_new_layers(layer_idx)
+ return self.layers[layer_idx].update(key_states, value_states, cache_kwargs)
- def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
- """Returns the sequence length of the cached states. A layer index can be optionally passed."""
- # TODO: deprecate this function in favor of `cache_position`
- if len(self.key_cache) <= layer_idx or (len(self.key_cache) > layer_idx and self.key_cache[layer_idx] == []):
+ def get_seq_length(self, layer_idx: int = 0, cache_position=None) -> int:
+ """Returns the sequence length of the cache for the given layer. TODO: deprecate in favor of cache_position"""
+ if layer_idx >= len(self.layers):
return 0
- return self.key_cache[layer_idx].shape[-2]
+ # Hack since QuantizedCache messes with keys shape as it becomes the residual cache
+ if self.cache_processor is not None and isinstance(self.cache_processor, QuantizedCacheProcessor):
+ return self.cache_processor.erased_length + self.layers[layer_idx].get_seq_length(cache_position)
+ return self.layers[layer_idx].get_seq_length(cache_position)
+
+ def get_mask_sizes(self, cache_position: ms.Tensor, layer_idx: int) -> tuple[int, int]:
+ """
+ Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
+ the given layer at `layer_idx`.
+ The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size),
+ for each layer.
+ """
+ kv_length, kv_offset = self.layers[layer_idx].get_mask_sizes(cache_position)
+ return kv_length, kv_offset
+
+ @property
+ def key_cache(self) -> KeyValuesWrapper:
+ """List-like object of key cache tensors indexed by layer. Deprecated in favor of `cache.layers[idx].keys`"""
+ logger.warning_once(
+ "`cache.key_cache[idx]` is deprecated and will be removed in v4.56.0. Use `cache.layers[idx].keys` instead."
+ )
+ return KeyValuesWrapper(self.layers, "keys")
+
+ @property
+ def value_cache(self) -> KeyValuesWrapper:
+ """List-like object of value cache tensors indexed by layer. Deprecated in favor of `cache.layers[idx].values`"""
+ logger.warning_once(
+ "`cache.value_cache[idx]` is deprecated and will be removed in v4.56.0. Use `cache.layers[idx].values` instead."
+ )
+ return KeyValuesWrapper(self.layers, "values")
+
+ # Wrappers for layer operations and properties ###
+
+ def get_max_cache_shape(self, layer_idx: int = 0) -> int:
+ """Returns maximum sequence length of the cache object. Dynamic caches do not have a maximum length."""
+ return self.layers[layer_idx].get_max_cache_shape()
+
+ def reset(self):
+ """Recursively reset all layers tensors"""
+ for layer_idx in range(len(self.layers)):
+ self.layers[layer_idx].reset()
+
+ def reorder_cache(self, beam_idx: ms.Tensor):
+ """Reorder the cache for beam search"""
+ for layer_idx in range(len(self.layers)):
+ self.layers[layer_idx].reorder_cache(beam_idx)
+
+ def crop(self, max_length: int):
+ """Crop the cache to the given length"""
+ for layer_idx in range(len(self.layers)):
+ self.layers[layer_idx].crop(max_length)
+
+ def batch_repeat_interleave(self, repeats: int):
+ """Repeat and interleave the cache"""
+ for layer_idx in range(len(self.layers)):
+ self.layers[layer_idx].batch_repeat_interleave(repeats)
+
+ def batch_select_indices(self, indices: ms.Tensor):
+ """Select indices from the cache"""
+ for layer_idx in range(len(self.layers)):
+ self.layers[layer_idx].batch_select_indices(indices)
+
+ @property
+ def max_batch_size(self) -> int:
+ """Return the maximum batch size of the cache"""
+ values = [layer.max_batch_size for layer in self.layers]
+ if len(set(values)) > 1:
+ raise ValueError(f"Max batch size is not consistent across layers: {values}")
+ return values[0]
+
+ @property
+ def max_cache_len(self) -> int:
+ """Return the maximum cache length of the cache"""
+ values = [layer.max_cache_len for layer in self.layers]
+ return max(values)
+
+ @property
+ def is_compileable(self) -> bool:
+ """Return whether the cache is compileable"""
+ return all(layer.is_compileable for layer in self.layers)
+
+ @property
+ def is_sliding(self) -> list[bool]:
+ """Return whether the layers of the cache are sliding window"""
+ return [getattr(layer, "is_sliding", False) for layer in self.layers]
- def get_max_length(self) -> Optional[int]:
- """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
- return None
- def to_legacy_cache(self) -> Tuple[Tuple[ms.Tensor], Tuple[ms.Tensor]]:
- """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format. Used for
- backward compatibility."""
+class DynamicCache(Cache):
+ """
+ A cache that grows dynamically as more tokens are generated. This is the default for generative models.
+
+ It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
+ `[batch_size, num_heads, seq_len, head_dim]`.
+
+ See `Cache` for details on common methods that are implemented by all cache classes.
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer
+ >>> from mindone.transformers import AutoModelForCausalLM, DynamicCache
+
+ >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
+ >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
+
+ >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt")
+
+ >>> # Prepare a cache class and pass it to model's forward
+ >>> past_key_values = DynamicCache()
+ >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
+ >>> outputs.past_key_values # access cache filled with key/values from generation
+ DynamicCache()
+ ```
+ """
+
+ # Specialized constructor for DDP cache data, needed for BC
+ def __init__(self, ddp_cache_data: Optional[Iterable[tuple[ms.Tensor, ms.Tensor]]] = None, *args, **kwargs):
+ super().__init__(layer_classes=DynamicLayer, *args, **kwargs)
+ # `ddp_cache_data` was originally added for compatibility with `torch.distributed` (DDP). See #36212
+ # and #36373 for more information. In a nutshell, it is `map(gather_map, zip(*caches))`, i.e. each item in the
+ # iterable contains the key and value states for a layer gathered across replicas by torch.distributed
+ # (shape=[global batch size, num_heads, seq_len, head_dim]).
+ # WARNING: `ddp_cache_data` must be the first argument in `__init__`, otherwise we'll break
+ # compatibility. The name of the argument doesn't matter.
+ if ddp_cache_data is not None:
+ for key_states, value_states in ddp_cache_data:
+ self.layers.append(DynamicLayer.from_tensors(key_states, value_states))
+
+ def to_legacy_cache(self) -> tuple[tuple[ms.Tensor, ms.Tensor], ...]:
+ """
+ Converts the `Cache` instance into its equivalent in the legacy cache format. Used for
+ backward compatibility.
+ """
legacy_cache = ()
- for layer_idx in range(len(self)):
- legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)
+ for layer in self.layers:
+ legacy_cache += ((layer.keys, layer.values),)
return legacy_cache
@classmethod
- def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[ms.Tensor]]] = None) -> "DynamicCache":
- """Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for
- backward compatibility."""
+ def from_legacy_cache(cls, past_key_values: tuple[tuple[ms.Tensor, ms.Tensor], ...]) -> "Cache":
+ """
+ Converts a cache in the legacy cache format into an equivalent `Cache`. Used for
+ backward compatibility.
+ """
cache = cls()
if past_key_values is not None:
for layer_idx in range(len(past_key_values)):
@@ -548,79 +1130,48 @@ def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[ms.Tensor]]] =
cache.update(key_states, value_states, layer_idx)
return cache
- def crop(self, max_length: int):
- """Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be
- negative to remove `max_length` tokens. This is used in assisted decoding and contrastive search."""
- # In case it is negative
- if max_length < 0:
- max_length = self.get_seq_length() - abs(max_length)
- if self.get_seq_length() <= max_length:
- return
+class StaticCache(Cache):
+ """
+ Static Cache class to be used with `mindspore.jit(model)`.
- self._seen_tokens = max_length
- for idx in range(len(self.key_cache)):
- self.key_cache[idx] = self.key_cache[idx][..., :max_length, :]
- self.value_cache[idx] = self.value_cache[idx][..., :max_length, :]
+ See `Cache` for details on common methods that are implemented by all cache classes.
- def batch_split(self, full_batch_size: int, split_size: int) -> List["DynamicCache"]:
- """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by
- `_split_model_inputs()` in `generation.utils`"""
- out = []
- for i in range(0, full_batch_size, split_size):
- current_split = DynamicCache()
- current_split._seen_tokens = self._seen_tokens
- current_split.key_cache = [tensor[i : i + split_size] for tensor in self.key_cache]
- current_split.value_cache = [tensor[i : i + split_size] for tensor in self.value_cache]
- out.append(current_split)
- return out
+ Example:
- @classmethod
- def from_batch_splits(cls, splits: List["DynamicCache"]) -> "DynamicCache":
- """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in
- `generation.utils`"""
- cache = cls()
- for idx in range(len(splits[0])):
- layer_keys = ops.cat([current.key_cache[idx] for current in splits], dim=0)
- layer_values = ops.cat([current.value_cache[idx] for current in splits], dim=0)
- cache.update(layer_keys, layer_values, idx)
- return cache
+ ```python
+ >>> from transformers import AutoTokenizer
+ >>> from mindone.transformers import AutoModelForCausalLM, StaticCache
- def batch_repeat_interleave(self, repeats: int):
- """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search."""
- for layer_idx in range(len(self)):
- self.key_cache[layer_idx] = ops.repeat_interleave(self.key_cache[layer_idx], repeats, dim=0)
- self.value_cache[layer_idx] = ops.repeat_interleave(self.value_cache[layer_idx], repeats, dim=0)
+ >>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
+ >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
- def batch_select_indices(self, indices: ms.Tensor):
- """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search."""
- for layer_idx in range(len(self)):
- self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...]
- self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...]
+ >>> inputs = tokenizer(text="My name is Llama", return_tensors="pt")
- def get_mask_sizes(self, cache_position: ms.Tensor, layer_idx: int) -> tuple[int, int]:
- """
- Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
- the given layer at `layer_idx`.
- The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size),
- for each layer.
- """
- query_length = cache_position.shape[0]
- past_seen_tokens = self.get_seq_length()
- kv_length = query_length + past_seen_tokens
- return kv_length, 0
+ >>> # Prepare a cache class and pass it to model's forward
+ >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
+ >>> max_generated_length = inputs.input_ids.shape[1] + 10
+ >>> past_key_values = StaticCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
+ >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
+ >>> outputs.past_key_values # access cache filled with key/values from generation
+ StaticCache()
+ ```
+ """
+ def __init__(self, *args, **kwargs):
+ super().__init__(layer_classes=StaticLayer, *args, **kwargs)
-class SlidingWindowCache(StaticCache):
+
+class SlidingWindowCache(Cache):
"""
- Sliding Window Cache class to be used with `torch.compile` for models like Mistral that support sliding window attention.
- Every time when we try to update the cache, we compute the `indices` based on `cache_position >= self.config.sliding_window - 1`,
+ Sliding Window Cache class to be used with `mindspore.jit` for models like Mistral that support sliding window attention.
+ Every time we try to update the cache, we compute the `indices` based on `cache_position >= self.sliding_window - 1`,
if true(which means the cache can not hold all the old key value states and new states together because of the sliding window constraint),
we need to do a cycle shift based on `indices` to replace the oldest states by the new key value states passed in.
The `to_shift` is only true once we are above sliding_window. Thus with `sliding_window==64`:
- indices = (slicing + to_shift[-1].int()-1) % self.config.sliding_window
+ indices = (slicing + to_shift[-1].sum()-1) % self.sliding_window
tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36,
37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
@@ -628,16 +1179,7 @@ class SlidingWindowCache(StaticCache):
We overwrite the cache using these, then we always write at cache_position (clamped to `sliding_window`)
- Parameters:
- config (`PretrainedConfig`):
- The configuration file defining the shape-related attributes required to initialize the static cache.
- batch_size (`int`):
- The batch size with which the model will be used. Note that a new instance must be instantiated if a
- smaller batch size is used.
- max_cache_len (`int`):
- The maximum sequence length with which the model will be used.
- dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
- The default `dtype` to use when initializing the layer.
+ See `Cache` for details on common methods that are implemented by all cache classes.
Example:
@@ -652,92 +1194,54 @@ class SlidingWindowCache(StaticCache):
>>> # Prepare a cache class and pass it to model's forward
>>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
>>> max_generated_length = inputs.input_ids.shape[1] + 10
- >>> past_key_values = SlidingWindowCache(config=model.config, batch_size=1, max_cache_len=max_generated_length, dtype=model.dtype)
+ >>> past_key_values = SlidingWindowCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, dtype=model.dtype)
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
>>> outputs.past_key_values # access cache filled with key/values from generation
SlidingWindowCache()
```
"""
- is_sliding = True
- is_compileable = True
-
- # TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
- def __init__(
- self,
- config: PretrainedConfig,
- batch_size: int = None,
- max_cache_len: int = None,
- dtype: ms.Type = ms.float32,
- max_batch_size: Optional[int] = None,
- ) -> None:
- if not hasattr(config, "sliding_window") or config.sliding_window is None:
- raise ValueError(
- "Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
- "sliding window attention, please check if there is a `sliding_window` field in the model "
- "config and it's not set to None."
- )
- max_cache_len = min(config.sliding_window, max_cache_len)
- super().__init__(
- config=config,
- batch_size=batch_size,
- max_cache_len=max_cache_len,
- dtype=dtype,
- max_batch_size=max_batch_size,
- )
+ def __init__(self, *args, **kwargs):
+ super().__init__(layer_classes=SlidingWindowLayer, *args, **kwargs)
- def update(
- self,
- key_states: ms.Tensor,
- value_states: ms.Tensor,
- layer_idx: int,
- cache_kwargs: Optional[Dict[str, Any]] = None,
- ) -> Tuple[ms.Tensor]:
- cache_position = cache_kwargs.get("cache_position")
- k_out = self.key_cache[layer_idx]
- v_out = self.value_cache[layer_idx]
- key_states = key_states.to(k_out.dtype)
- value_states = value_states.to(v_out.dtype)
-
- # assume this only happens in prefill phase when prompt length > sliding_window_size (= max_cache_len)
- if cache_position.shape[0] > self.max_cache_len:
- k_out = key_states[:, :, -self.max_cache_len :, :]
- v_out = value_states[:, :, -self.max_cache_len :, :]
- # Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly
- self.key_cache[layer_idx] += k_out
- self.value_cache[layer_idx] += v_out
- # we should return the whole states instead of k_out, v_out to take the whole prompt
- # into consideration when building kv cache instead of just throwing away tokens outside of the window
- return key_states, value_states
- slicing = ops.ones(self.max_cache_len, dtype=ms.int32).cumsum(0)
- cache_position = cache_position.clamp(0, self.max_cache_len - 1)
- to_shift = cache_position >= self.max_cache_len - 1
- indices = (slicing + to_shift[-1].int() - 1) % self.max_cache_len
+class HybridCache(Cache):
+ """
+ Hybrid Cache class to be used with `mindspore.jit` for models that alternate between local sliding window
+ attention and global attention in every other layer (originally implemented for Gemma2).
+ Under the hood, the hybrid cache uses a `SlidingWindowLayer` for sliding window attention and a `StaticLayer`
+ for global attention. For more information, see the documentation of those layer classes.
- k_out = k_out[:, :, indices]
- v_out = v_out[:, :, indices]
+ See `Cache` for details on common methods that are implemented by all cache classes.
- k_out[:, :, cache_position] = key_states
- v_out[:, :, cache_position] = value_states
+ Example:
- # `_.zero()` followed by `+=` is equivalent `=`, but compile-friendly (without graph breaks due to assignment)
- self.key_cache[layer_idx] = mint.zeros_like(self.key_cache[layer_idx])
- self.value_cache[layer_idx] = mint.zeros_like(self.value_cache[layer_idx])
+ ```python
+ >>> from transformers import AutoTokenizer
+ >>> from mindone.transformers import AutoModelForCausalLM, HybridCache
- self.key_cache[layer_idx] += k_out
- self.value_cache[layer_idx] += v_out
+ >>> model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b")
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
- return k_out, v_out
+ >>> inputs = tokenizer(text="My name is Gemma", return_tensors="pt")
- def get_max_cache_shape(self) -> Optional[int]:
- return self.max_cache_len
+ >>> # Prepare a cache class and pass it to model's forward
+ >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
+ >>> max_generated_length = inputs.input_ids.shape[1] + 10
+ >>> past_key_values = HybridCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, dtype=model.dtype)
+ >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
+ >>> outputs.past_key_values # access cache filled with key/values from generation
+ HybridCache()
+ ```
+ """
- def reset(self):
- for layer_idx in range(len(self.key_cache)):
- # In-place ops prevent breaking the static address
- self.key_cache[layer_idx] = mint.zeros_like(self.key_cache[layer_idx])
- self.value_cache[layer_idx] = mint.zeros_like(self.value_cache[layer_idx])
+ def __init__(self, config: PretrainedConfig, *args, **kwargs):
+ if hasattr(config, "layer_types"):
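+ # Each entry in `config.layer_types` (e.g. "sliding_attention" or "full_attention") selects the
+ # matching cache layer class via `LAYER_CLASS_MAP`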
+ layer_classes = [LAYER_CLASS_MAP[layer_type] for layer_type in config.layer_types]
+ else:
+ # In this case, fall back to StaticCache
+ layer_classes = [StaticLayer] * config.num_hidden_layers
+ super().__init__(config=config, layer_classes=layer_classes, *args, **kwargs)
class EncoderDecoderCache(Cache):
@@ -745,10 +1249,12 @@ class EncoderDecoderCache(Cache):
Base, abstract class for all encoder-decoder caches. Can be used to hold combinations of self-attention and
cross-attention caches.
+ See `Cache` for details on common methods that are implemented by all cache classes.
+
Example:
```python
- >>> from transformers import AutoProcessor, AutoModelForCausalLM, DynamicCache, EncoderDecoderCache
+ >>> from mindone.transformers import AutoProcessor, AutoModelForCausalLM, DynamicCache, EncoderDecoderCache
>>> model = AutoModelForCausalLM.from_pretrained("openai/whisper-small")
>>> processor = AutoProcessor.from_pretrained("openai/whisper-small")
@@ -766,27 +1272,43 @@ class EncoderDecoderCache(Cache):
"""
+ # Override @property from Cache
+ is_compileable = None
+
def __init__(self, self_attention_cache: Cache, cross_attention_cache: Cache):
- super().__init__()
+ super().__init__(layer_classes=DynamicLayer)
self.self_attention_cache = self_attention_cache
self.cross_attention_cache = cross_attention_cache
self.is_compileable = getattr(self.self_attention_cache, "is_compileable", False)
self.is_updated = {}
- for layer_idx in range(len(cross_attention_cache.key_cache)):
+ for layer_idx in range(len(cross_attention_cache)):
self.is_updated[layer_idx] = bool(cross_attention_cache.get_seq_length(layer_idx) > 0)
- def __getitem__(self, layer_idx: int) -> List[Tuple[ms.Tensor]]:
+ def __iter__(self):
+ """
+ Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over
+ keys and values
+ """
+ for layer_idx in range(len(self)):
+ yield (
+ self.self_attention_cache.layers[layer_idx].keys,
+ self.self_attention_cache.layers[layer_idx].values,
+ self.cross_attention_cache.layers[layer_idx].keys,
+ self.cross_attention_cache.layers[layer_idx].values,
+ )
+
+ def __getitem__(self, layer_idx: int) -> tuple[ms.Tensor, ms.Tensor, ms.Tensor, ms.Tensor]:
"""
Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
sequence length.
"""
if layer_idx < len(self):
return (
- self.self_attention_cache.key_cache[layer_idx],
- self.self_attention_cache.value_cache[layer_idx],
- self.cross_attention_cache.key_cache[layer_idx],
- self.cross_attention_cache.value_cache[layer_idx],
+ self.self_attention_cache.layers[layer_idx].keys,
+ self.self_attention_cache.layers[layer_idx].values,
+ self.cross_attention_cache.layers[layer_idx].keys,
+ self.cross_attention_cache.layers[layer_idx].values,
)
else:
raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
@@ -798,8 +1320,8 @@ def __len__(self):
"""
return len(self.self_attention_cache)
- def to_legacy_cache(self) -> Tuple[Tuple[ms.Tensor], Tuple[ms.Tensor]]:
- """Converts the `EncoderDecoderCache` instance into its equivalent in the legacy cache format."""
+ def to_legacy_cache(self) -> tuple[tuple[ms.Tensor]]:
+ """Converts the `EncoderDecoderCache` instance into its equivalent in the legacy cache format."""
legacy_cache = ()
if len(self.cross_attention_cache) > 0:
for self_attn, cross_attn in zip(
@@ -811,7 +1333,7 @@ def to_legacy_cache(self) -> Tuple[Tuple[ms.Tensor], Tuple[ms.Tensor]]:
return legacy_cache
@classmethod
- def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[ms.Tensor]]] = None) -> "EncoderDecoderCache":
+ def from_legacy_cache(cls, past_key_values: tuple[tuple[ms.Tensor, ms.Tensor], ...]) -> "EncoderDecoderCache":
"""Converts a cache in the legacy cache format into an equivalent `EncoderDecoderCache`."""
cache = cls(
self_attention_cache=DynamicCache(),
@@ -827,10 +1349,10 @@ def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[ms.Tensor]]] =
cache.is_updated[layer_idx] = True
return cache
- def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+ def get_seq_length(self, layer_idx: Optional[int] = 0, cache_position=None) -> int:
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
- # check if empty list because in case of static cache it will be a tensors and we can't check `if not torch.Tensor`
- return self.self_attention_cache.get_seq_length(layer_idx)
+ # check for an empty list because with a static cache the entries are tensors, so we cannot rely on a simple truthiness check like `if not ms.Tensor`
+ return self.self_attention_cache.get_seq_length(layer_idx, cache_position)
def reset(self):
if hasattr(self.self_attention_cache, "reset"):
@@ -863,14 +1385,18 @@ def check_dynamic_cache(self, method: str):
# TODO(gante, sanchit-gandhi): move following functionality into `.generate`
def crop(self, maximum_length: int):
- """Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be
- negative to remove `maximum_length` tokens. This is used in assisted decoding and contrastive search."""
+ """
+ Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be
+ negative to remove `maximum_length` tokens. This is used in assisted decoding and contrastive search.
+ """
self.check_dynamic_cache(self.crop.__name__)
self.self_attention_cache.crop(maximum_length)
- def batch_split(self, full_batch_size: int, split_size: int) -> "List[EncoderDecoderCache]":
- """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by
- `_split_model_inputs()` in `generation.utils`"""
+ def batch_split(self, full_batch_size: int, split_size: int) -> "list[EncoderDecoderCache]":
+ """
+ Split the current instance into a list of `DynamicCache` by the batch size. This will be used by
+ `_split_model_inputs()` in `generation.utils`
+ """
self.check_dynamic_cache(self.batch_split.__name__)
self_attention_cache = self.self_attention_cache.batch_split(full_batch_size, split_size)
cross_attention_cache = self.cross_attention_cache.batch_split(full_batch_size, split_size)
@@ -880,22 +1406,6 @@ def batch_split(self, full_batch_size: int, split_size: int) -> "List[EncoderDec
out.append(EncoderDecoderCache(self_attn, cross_attn))
return out
- @classmethod
- def from_batch_splits(cls, splits: List["EncoderDecoderCache"]) -> "EncoderDecoderCache":
- """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in
- `generation.utils`"""
- self_attention_cache = DynamicCache()
- cross_attention_cache = DynamicCache()
- for idx in range(len(splits[0])):
- layer_keys = ops.cat([current.self_attention_cache.key_cache[idx] for current in splits], axis=0)
- layer_values = ops.cat([current.self_attention_cache.value_cache[idx] for current in splits], axis=0)
- self_attention_cache.update(layer_keys, layer_values, idx)
-
- layer_keys = ops.cat([current.cross_attention_cache.key_cache[idx] for current in splits], axis=0)
- layer_values = ops.cat([current.cross_attention_cache.value_cache[idx] for current in splits], axis=0)
- cross_attention_cache.update(layer_keys, layer_values, idx)
- return cls(self_attention_cache, cross_attention_cache)
-
def batch_repeat_interleave(self, repeats: int):
"""Repeat the cache `repeats` times in the batch dimension. Used in contrastive search."""
self.check_dynamic_cache(self.batch_repeat_interleave.__name__)
@@ -908,227 +1418,41 @@ def batch_select_indices(self, indices: ms.Tensor):
self.self_attention_cache.batch_select_indices(indices)
self.cross_attention_cache.batch_select_indices(indices)
+ def get_max_cache_shape(self) -> int:
+ """Returns the maximum sequence length (i.e. max capacity) of the cache object"""
+ return self.self_attention_cache.get_max_cache_shape()
-class HybridCache(Cache):
- """
- Hybrid Cache class to be used with `torch.compile` for Gemma2 models that alternate between a local sliding window attention
- and global attention in every other layer. Under the hood, Hybrid Cache leverages ["SlidingWindowCache"] for sliding window attention
- and ["StaticCache"] for global attention. For more information, see the documentation of each subcomponeent cache class.
-
- Parameters:
- config (`PretrainedConfig):
- The configuration file defining the shape-related attributes required to initialize the static cache.
- batch_size (`int`):
- The batch size with which the model will be used. Note that a new instance must be instantiated if a
- smaller batch size is used.
- max_cache_len (`int`):
- The maximum sequence length with which the model will be used.
- dtype (torch.dtype, *optional*, defaults to `torch.float32`):
- The default `dtype` to use when initializing the layer.
-
- Example:
-
- ```python
- >>> from transformers import AutoTokenizer, AutoModelForCausalLM, HybridCache
-
- >>> model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b")
- >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
-
- >>> inputs = tokenizer(text="My name is Gemma", return_tensors="pt")
-
- >>> # Prepare a cache class and pass it to model's forward
- >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
- >>> max_generated_length = inputs.input_ids.shape[1] + 10
- >>> past_key_values = HybridCache(config=model.config, batch_size=1, max_cache_len=max_generated_length, dtype=model.dtype)
- >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
- >>> outputs.past_key_values # access cache filled with key/values from generation
- HybridCache()
- ```
- """
-
- # TODO (joao): dive deeper into gemma2 and paligemma -- there are reports of speed loss with compilation. Revert
- # ALL changes from the PR that commented the line below when reactivating it.
- # is_compileable = True
-
- # TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
- def __init__(
- self,
- config: PretrainedConfig,
- batch_size: int = None,
- max_cache_len: int = None,
- dtype: ms.Type = ms.float32,
- max_batch_size: Optional[int] = None,
- ) -> None:
- super().__init__()
- if batch_size is not None:
- logger.warning_once(
- f"The 'batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
- "v4.49. Use the more precisely named 'max_batch_size' argument instead."
- )
- if not hasattr(config, "sliding_window") or config.sliding_window is None:
- raise ValueError(
- "Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
- "sliding window attention, please check if there is a `sliding_window` field in the model "
- "config and it's not set to None."
- )
- self.max_cache_len = max_cache_len
- self.max_batch_size = batch_size or max_batch_size
- # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
- self.head_dim = (
- config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
- )
-
- self.dtype = dtype
- self.num_key_value_heads = (
- config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
- )
-
- layer_switch = config.sliding_window_pattern if hasattr(config, "sliding_window_pattern") else 2 # 2 is for BC
- self.is_sliding = ms.tensor(
- [bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)], dtype=ms.bool_
- )
- self.key_cache: List[ms.Tensor] = []
- self.value_cache: List[ms.Tensor] = []
- global_cache_shape = (self.max_batch_size, self.num_key_value_heads, max_cache_len, self.head_dim)
- sliding_cache_shape = (
- self.max_batch_size,
- self.num_key_value_heads,
- min(config.sliding_window, max_cache_len),
- self.head_dim,
- )
- for i in range(config.num_hidden_layers):
- # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
- # breaks when updating the cache.
- cache_shape = global_cache_shape if not self.is_sliding[i] else sliding_cache_shape
- new_layer_key_cache = ops.zeros(cache_shape, dtype=self.dtype)
- new_layer_value_cache = ops.zeros(cache_shape, dtype=self.dtype)
- self.key_cache.append(new_layer_key_cache)
- self.value_cache.append(new_layer_value_cache)
-
- def _sliding_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
- if cache_position.shape[0] > max_cache_len:
- k_out = key_states[:, :, -max_cache_len:, :]
- v_out = value_states[:, :, -max_cache_len:, :]
- # Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly
- self.key_cache[layer_idx] += k_out
- self.value_cache[layer_idx] += v_out
- # we should return the whole states instead of k_out, v_out to take the whole prompt
- # into consideration when building kv cache instead of just throwing away tokens outside of the window
- return key_states, value_states
-
- slicing = ops.ones(max_cache_len, dtype=ms.int32).cumsum(0)
- cache_position = cache_position.clamp(0, max_cache_len - 1)
- to_shift = cache_position >= max_cache_len - 1
- indices = (slicing + to_shift[-1].int() - 1) % max_cache_len
- k_out = k_out[:, :, indices]
- v_out = v_out[:, :, indices]
-
- k_out[:, :, cache_position] = key_states
- v_out[:, :, cache_position] = value_states
- # `_.zero()` followed by `+=` is equivalent `=`, but compile-friendly (without graph breaks due to assignment)
- self.key_cache[layer_idx] = mint.zeros_like(self.key_cache[layer_idx])
- self.value_cache[layer_idx] = mint.zeros_like(self.value_cache[layer_idx])
-
- self.key_cache[layer_idx] += k_out
- self.value_cache[layer_idx] += v_out
- return k_out, v_out
-
- def _static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
- k_out[:, :, cache_position] = key_states
- v_out[:, :, cache_position] = value_states
-
- self.key_cache[layer_idx] = k_out
- self.value_cache[layer_idx] = v_out
- return k_out, v_out
-
- def update(
- self,
- key_states: ms.Tensor,
- value_states: ms.Tensor,
- layer_idx: int,
- cache_kwargs: Optional[Dict[str, Any]] = None,
- ) -> Tuple[ms.Tensor]:
- cache_position = cache_kwargs.get("cache_position")
- sliding_window = cache_kwargs.get("sliding_window")
-
- k_out = self.key_cache[layer_idx]
- v_out = self.value_cache[layer_idx]
- key_states = key_states.to(k_out.dtype)
- value_states = value_states.to(v_out.dtype)
-
- if sliding_window:
- update_fn = self._sliding_update
- else:
- update_fn = self._static_update
-
- return update_fn(
- cache_position,
- layer_idx,
- key_states,
- value_states,
- k_out,
- v_out,
- k_out.shape[2],
- )
-
- def get_max_cache_shape(self) -> Optional[int]:
- return self.max_cache_len
-
- def get_seq_length(self, layer_idx: Optional[int] = 0):
- # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
- # limit the check to the first batch member and head dimension.
- # TODO: deprecate this function in favor of `cache_position`
- if layer_idx != 0:
- raise ValueError(
- "`get_seq_length` on `HybridCache` may get inconsistent results depending on the layer index. "
- "Using the `layer_idx` argument is not supported."
- )
- return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()
-
- def reset(self):
- """Resets the cache values while preserving the objects"""
- for layer_idx in range(len(self.key_cache)):
- # In-place ops prevent breaking the static address
- self.key_cache[layer_idx] = mint.zeros_like(self.key_cache[layer_idx])
- self.value_cache[layer_idx] = mint.zeros_like(self.value_cache[layer_idx])
-
- @property
- def batch_size(self):
- logger.warning_once(
- f"The 'batch_size' attribute of {self.__class__.__name__} is deprecated and will be removed in "
- "v4.49. Use the more precisely named 'self.max_batch_size' attribute instead."
- )
- return self.max_batch_size
+ def get_mask_sizes(self, cache_position: ms.Tensor, layer_idx: int) -> tuple[int, int]:
+ return self.self_attention_cache.get_mask_sizes(cache_position, layer_idx)
class MambaCache:
"""
+ Importing `MambaCache` from `transformers.cache_utils` is deprecated and will be removed
+ in a future version. Please import it from `transformers` or `transformers.models.mamba.cache_mamba` instead.
+
Cache for mamba model which does not have attention mechanism and key value states.
Arguments:
config (`PretrainedConfig):
The configuration file defining the shape-related attributes required to initialize the static cache.
- batch_size (`int`):
- The batch size with which the model will be used. Note that a new instance must be instantiated if a
- smaller batch size is used.
- dtype (`mindspore.Type`, *optional*, defaults to `ms.float16`):
+ max_batch_size (`int`):
+ The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a smaller batch size is used.
+ dtype (`ms.Type`, *optional*, defaults to `ms.float16`):
The default `dtype` to use when initializing the layer.
Example:
```python
- >>> from transformers import AutoTokenizer
- >>> from mindone.transformers import MambaForCausalLM, MambaCache
+ >>> from transformers import AutoTokenizer
+ >>> from mindone.transformers import MambaForCausalLM, MambaCache
>>> model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")
>>> tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
- >>> inputs = tokenizer(text="My name is Mamba", return_tensors="np")
- >>> for k,v in inputs.items():
- >>> inputs[k] = ms.tensor(v)
+ >>> inputs = tokenizer(text="My name is Mamba", return_tensors="np")
+ >>> for k, v in inputs.items():
+ >>>     inputs[k] = ms.tensor(v)
>>> # Prepare a cache class and pass it to model's forward
- >>> past_key_values = MambaCache(config=model.config, batch_size=1, dtype=model.dtype)
+ >>> past_key_values = MambaCache(config=model.config, max_batch_size=1, dtype=model.dtype)
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
>>> outputs.past_key_values
MambaCache()
@@ -1137,36 +1461,29 @@ class MambaCache:
is_compileable = True
- # TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
# TODO (joao): add layer_device_map arg and update code in `generate` accordingly
def __init__(
self,
- config: PretrainedConfig,
- batch_size: int = None,
+ config,
+ max_batch_size: int,
dtype: ms.Type = ms.float16,
- max_batch_size: Optional[int] = None,
):
- if batch_size is not None:
- logger.warning_once(
- f"The 'batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
- "v4.49. Use the more precisely named 'max_batch_size' argument instead."
- )
- self.dtype = dtype
- self.max_batch_size = batch_size or max_batch_size
+ self.max_batch_size = max_batch_size
+ self._dtype = dtype
self.intermediate_size = config.intermediate_size
self.ssm_state_size = config.state_size
self.conv_kernel_size = config.conv_kernel
- self.conv_states: List[ms.Tensor] = []
- self.ssm_states: List[ms.Tensor] = []
+ self.conv_states: list[ms.Tensor] = []
+ self.ssm_states: list[ms.Tensor] = []
for _ in range(config.num_hidden_layers):
conv_state: ms.Tensor = mint.zeros(
(self.max_batch_size, self.intermediate_size, self.conv_kernel_size),
- dtype=dtype,
+ dtype=self._dtype,
)
ssm_state: ms.Tensor = mint.zeros(
(self.max_batch_size, self.intermediate_size, self.ssm_state_size),
- dtype=dtype,
+ dtype=self._dtype,
)
self.conv_states.append(conv_state)
@@ -1192,15 +1509,205 @@ def reset(self):
self.conv_states[layer_idx].zero_()
self.ssm_states[layer_idx].zero_()
- @property
- def batch_size(self):
- logger.warning_once(
- f"The 'batch_size' attribute of {self.__class__.__name__} is deprecated and will be removed in "
- "v4.49. Use the more precisely named 'self.max_batch_size' attribute instead."
- )
- return self.max_batch_size
-
class OffloadedStaticCache(StaticCache):
def __init__(self):
raise NotImplementedError
+
+
+def parse_processor_args(processor_class: Optional[type["CacheProcessor"]], kwargs: dict) -> tuple[dict, dict]:
+ """
+ Parse processor arguments from kwargs based on the processor class init signature.
+
+ Args:
+ processor_class: The processor class to inspect, or None
+ kwargs: Dictionary of keyword arguments
+
+ Returns:
+ tuple: (processor_kwargs, remaining_kwargs)
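+
+    Example (illustrative sketch; `DummyProcessor` is a hypothetical processor class, assumed to take the cache
+    object as the first argument of its `__init__`, like the processors in `PROCESSOR_CLASS_MAP`):
+
+        >>> class DummyProcessor:
+        ...     def __init__(self, cache, residual_length=128, axis_key=0):
+        ...         self.residual_length = residual_length
+        ...         self.axis_key = axis_key
+        >>> parse_processor_args(DummyProcessor, {"residual_length": 64, "max_cache_len": 1024})
+        ({'residual_length': 64}, {'max_cache_len': 1024})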
+ """
+ try:
+ params = list(inspect.signature(processor_class.__init__).parameters)[2:]
+ except Exception:
+ return {}, kwargs
+
+ processor_kwargs = {k: kwargs[k] for k in params if k in kwargs}
+ remaining_kwargs = {k: v for k, v in kwargs.items() if k not in processor_kwargs}
+ return processor_kwargs, remaining_kwargs
+
+
+def parse_layer_args_from_model_config(
+ config: Optional[PretrainedConfig],
+ batch_size: Optional[int] = None,
+ max_cache_len: Optional[int] = None,
+ dtype: Optional[ms.Type] = None,
+ tp_size: Optional[int] = None,
+ max_batch_size: Optional[int] = None,
+) -> dict:
+ """
+ Parse layer arguments from model configuration for cache initialization.
+
+ Args:
+ config (`Optional[PretrainedConfig]`): Model configuration containing shape/device info.
+ batch_size (`Optional[int]`): Batch size for cache initialization.
+ max_cache_len (`Optional[int]`): Maximum sequence length for cache.
+ dtype (`Optional[ms.Type]`): Data type for cache tensors.
+ tp_size (`Optional[int]`): Tensor parallel size to adjust number of key/value heads.
+ max_batch_size (`Optional[int]`): Maximum batch size for cache initialization.
+
+ Returns:
+ `dict`: Dictionary containing parsed layer arguments for cache initialization.
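+
+    Example (illustrative sketch; a real `PretrainedConfig` would normally be passed, a `SimpleNamespace` carrying
+    the relevant fields stands in here):
+
+        >>> from types import SimpleNamespace
+        >>> cfg = SimpleNamespace(layer_types=None, hidden_size=1024, num_attention_heads=16,
+        ...                       num_key_value_heads=4, max_position_embeddings=2048)
+        >>> args = parse_layer_args_from_model_config(cfg, max_batch_size=2, max_cache_len=512, dtype=ms.float16)
+        >>> sorted(args)  # `sliding_window` is dropped because it is None
+        ['batch_size', 'dtype', 'head_dim', 'max_cache_len', 'num_heads']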
+ """
+ # No model config -> must be a dynamic cache, return bare dict
+ if config is None:
+ return {}
+ # Build the args dict for hybrid, sliding or static
+ else:
+ # Hybrid/sliding caches require a config that supports sliding window attention
+ if (
+ getattr(config, "layer_types", None) is not None
+ and "sliding_attention" in config.layer_types
+ and "full_attention" in config.layer_types
+ ):
+ if getattr(config, "sliding_window", None) is None:
+ raise ValueError(
+ "Setting up a hybrid or sliding window KVCache requires the model config supporting "
+ "sliding window attention, please check if there is a `sliding_window` field in the model "
+ "config and it's not set to None."
+ )
+ # Default max_cache_len to the model's maximum position embeddings when it is not provided
+ max_cache_len = max_cache_len or config.max_position_embeddings
+ # Some models define a custom `head_dim` != config.hidden_size // config.num_attention_heads:
+ head_dim = (
+ config.head_dim
+ if getattr(config, "head_dim", None) is not None
+ else config.hidden_size // config.num_attention_heads
+ )
+ num_heads = (
+ config.num_attention_heads
+ if getattr(config, "num_key_value_heads", None) is None
+ else config.num_key_value_heads
+ )
+ if tp_size is not None and tp_size > 1:
+ if num_heads % tp_size != 0:
+ raise ValueError(
+ f"Number of key value heads {num_heads} must be divisible by tensor parallel size {tp_size}."
+ )
+ # If the model is using tensor parallelism, we need to adjust the number of heads accordingly.
+ num_heads //= tp_size
+ layer_args = {
+ "batch_size": max_batch_size if max_batch_size is not None else batch_size,
+ "max_cache_len": max_cache_len,
+ "dtype": dtype,
+ "head_dim": head_dim,
+ "num_heads": num_heads,
+ "sliding_window": getattr(config, "sliding_window", None),
+ }
+ return {k: v for k, v in layer_args.items() if v is not None}
+
+
+LAYER_CLASS_MAP: dict[str, type["CacheLayerMixin"]] = {
+ "full_attention": StaticLayer,
+ "sliding_attention": SlidingWindowLayer,
+ "chunked_attention": ChunkedSlidingLayer,
+}
+PROCESSOR_CLASS_MAP: dict[str, type["CacheProcessor"]] = {
+ "offloaded": OffloadedCacheProcessor,
+ "quanto_quantized": QuantizedCacheProcessor,
+ "hqq_quantized": HQQQuantizedCacheProcessor,
+}
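+# Illustrative note (an assumption about how these maps are consumed elsewhere): a layer type string such as
+# "sliding_attention" from `config.layer_types` would select `SlidingWindowLayer` via `LAYER_CLASS_MAP`, while a
+# cache implementation string such as "offloaded" would select `OffloadedCacheProcessor` via `PROCESSOR_CLASS_MAP`.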
+
+
+class CacheConfig:
+ """
+ Base class for cache configs
+ """
+
+ cache_implementation: None
+
+ @classmethod
+ def from_dict(cls, config_dict, **kwargs):
+ """
+ Constructs a CacheConfig instance from a dictionary of parameters.
+ Args:
+ config_dict (Dict[str, Any]): Dictionary containing configuration parameters.
+ **kwargs: Additional keyword arguments to override dictionary values.
+ Returns:
+ CacheConfig: Instance of CacheConfig constructed from the dictionary.
+ """
+ config = cls(**config_dict)
+ to_remove = []
+ for key, value in kwargs.items():
+ if hasattr(config, key):
+ setattr(config, key, value)
+ to_remove.append(key)
+ for key in to_remove:
+ kwargs.pop(key, None)
+ return config
+
+ # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_json_file
+ def to_json_file(self, json_file_path: Union[str, os.PathLike]):
+ """
+ Save this instance to a JSON file.
+
+ Args:
+ json_file_path (`str` or `os.PathLike`):
+ Path to the JSON file in which this configuration instance's parameters will be saved.
+ use_diff (`bool`, *optional*, defaults to `True`):
+ If set to `True`, only the difference between the config instance and the default
+ `QuantizationConfig()` is serialized to JSON file.
+ """
+ with open(json_file_path, "w", encoding="utf-8") as writer:
+ config_dict = self.to_dict()
+ json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
+
+ writer.write(json_string)
+
+ # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_dict
+ def to_dict(self) -> Dict[str, Any]:
+ """
+ Serializes this instance to a Python dictionary. Returns:
+ `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
+ """
+ return copy.deepcopy(self.__dict__)
+
+ # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__iter__
+ def __iter__(self):
+ """allows `dict(obj)` for situations where obj may be a dict or QuantizationConfigMixin"""
+ for attr, value in copy.deepcopy(self.__dict__).items():
+ yield attr, value
+
+ # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__repr__
+ def __repr__(self):
+ return f"{self.__class__.__name__} {self.to_json_string()}"
+
+ def to_json_string(self):
+ """
+ Serializes this instance to a JSON formatted string.
+ Returns:
+ str: JSON formatted string representing the configuration instance.
+ """
+ return json.dumps(self.__dict__, indent=2) + "\n"
+
+ # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.update
+ def update(self, **kwargs):
+ """
+ Updates attributes of this class instance with attributes from `kwargs` if they match existing attributes,
+ returning all the unused kwargs.
+
+ Args:
+ kwargs (`Dict[str, Any]`):
+ Dictionary of attributes to tentatively update this class.
+
+ Returns:
+ `Dict[str, Any]`: Dictionary containing all the key-value pairs that were not used to update the instance.
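+
+        Example (minimal sketch using the base class directly; concrete subclasses typically carry more attributes):
+
+            >>> cfg = CacheConfig()
+            >>> cfg.cache_implementation = "static"
+            >>> cfg.update(cache_implementation="sliding_window", unknown_flag=True)
+            {'unknown_flag': True}
+            >>> cfg.cache_implementation
+            'sliding_window'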
+ """
+ to_remove = []
+ for key, value in kwargs.items():
+ if hasattr(self, key):
+ setattr(self, key, value)
+ to_remove.append(key)
+
+ # Remove all the attributes that were updated, without modifying the input dict
+ unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove}
+ return unused_kwargs
diff --git a/mindone/transformers/feature_extraction_utils.py b/mindone/transformers/feature_extraction_utils.py
index f93e8918e9..eff1cec771 100644
--- a/mindone/transformers/feature_extraction_utils.py
+++ b/mindone/transformers/feature_extraction_utils.py
@@ -23,7 +23,7 @@
import os
import warnings
from collections import UserDict
-from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union
import numpy as np
from transformers.dynamic_module_utils import custom_object_save
@@ -40,15 +40,7 @@
logging,
)
-from .utils import (
- TensorType,
- add_model_info_to_auto_map,
- add_model_info_to_custom_pipelines,
- is_mindspore_available,
- is_mindspore_tensor,
- is_numpy_array,
- requires_backends,
-)
+from .utils import TensorType, is_mindspore_available, is_mindspore_tensor, is_numpy_array, requires_backends
if TYPE_CHECKING:
if is_mindspore_available():
@@ -59,6 +51,9 @@
PreTrainedFeatureExtractor = Union["SequenceFeatureExtractor"] # noqa: F821
+# type hinting: specifying the type of feature extractor class that inherits from FeatureExtractionMixin
+SpecificFeatureExtractorType = TypeVar("SpecificFeatureExtractorType", bound="FeatureExtractionMixin")
+
class BatchFeature(UserDict):
r"""
@@ -75,7 +70,7 @@ class BatchFeature(UserDict):
initialization.
"""
- def __init__(self, data: Optional[Dict[str, Any]] = None, tensor_type: Union[None, str, TensorType] = None):
+ def __init__(self, data: Optional[dict[str, Any]] = None, tensor_type: Union[None, str, TensorType] = None):
super().__init__(data)
self.convert_to_tensors(tensor_type=tensor_type)
@@ -102,18 +97,6 @@ def __setstate__(self, state):
if "data" in state:
self.data = state["data"]
- # Copied from transformers.tokenization_utils_base.BatchEncoding.keys
- def keys(self):
- return self.data.keys()
-
- # Copied from transformers.tokenization_utils_base.BatchEncoding.values
- def values(self):
- return self.data.values()
-
- # Copied from transformers.tokenization_utils_base.BatchEncoding.items
- def items(self):
- return self.data.items()
-
def _get_is_as_tensor_fns(self, tensor_type: Optional[Union[str, TensorType]] = None):
if tensor_type is None:
return None, None
@@ -191,7 +174,7 @@ def convert_to_tensors(self, tensor_type: Optional[Union[str, TensorType]] = Non
def to(self, *args, **kwargs) -> "BatchFeature":
"""
Args:
- args (`Tuple`):
+ args (`tuple`):
Will be passed to the `to(...)` function of the tensors.
kwargs (`Dict`, *optional*):
Will be passed to the `to(...)` function of the tensors.
@@ -243,7 +226,7 @@ def _set_processor_class(self, processor_class: str):
@classmethod
def from_pretrained(
- cls,
+ cls: type[SpecificFeatureExtractorType],
pretrained_model_name_or_path: Union[str, os.PathLike],
cache_dir: Optional[Union[str, os.PathLike]] = None,
force_download: bool = False,
@@ -251,7 +234,7 @@ def from_pretrained(
token: Optional[Union[str, bool]] = None,
revision: str = "main",
**kwargs,
- ):
+ ) -> SpecificFeatureExtractorType:
r"""
Instantiate a type of [`~feature_extraction_utils.FeatureExtractionMixin`] from a feature extractor, *e.g.* a
derived class of [`SequenceFeatureExtractor`].
@@ -276,12 +259,12 @@ def from_pretrained(
resume_download:
Deprecated and ignored. All downloads are now resumed by default when possible.
Will be removed in v5 of Transformers.
- proxies (`Dict[str, str]`, *optional*):
+ proxies (`dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
token (`str` or `bool`, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
- the token generated when running `huggingface-cli login` (stored in `~/.huggingface`).
+ the token generated when running `hf auth login` (stored in `~/.huggingface`).
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
@@ -296,10 +279,10 @@ def from_pretrained(
return_unused_kwargs (`bool`, *optional*, defaults to `False`):
If `False`, then this function returns just the final feature extractor object. If `True`, then this
- functions returns a `Tuple(feature_extractor, unused_kwargs)` where *unused_kwargs* is a dictionary
+ functions returns a `tuple(feature_extractor, unused_kwargs)` where *unused_kwargs* is a dictionary
consisting of the key/value pairs whose keys are not feature extractor attributes: i.e., the part of
`kwargs` which has not been used to update `feature_extractor` and is otherwise ignored.
- kwargs (`Dict[str, Any]`, *optional*):
+ kwargs (`dict[str, Any]`, *optional*):
The values in kwargs of any keys which are feature extractor attributes will be used to override the
loaded values. Behavior concerning key/value pairs whose keys are *not* feature extractor attributes is
controlled by the `return_unused_kwargs` keyword parameter.
@@ -365,7 +348,7 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub:
Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
namespace).
- kwargs (`Dict[str, Any]`, *optional*):
+ kwargs (`dict[str, Any]`, *optional*):
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
"""
use_auth_token = kwargs.pop("use_auth_token", None)
@@ -417,7 +400,7 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub:
@classmethod
def get_feature_extractor_dict(
cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
- ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+ ) -> tuple[dict[str, Any], dict[str, Any]]:
"""
From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating a
feature extractor of type [`~feature_extraction_utils.FeatureExtractionMixin`] using `from_dict`.
@@ -427,7 +410,7 @@ def get_feature_extractor_dict(
The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.
Returns:
- `Tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the feature extractor object.
+ `tuple[dict, dict]`: The dictionary(ies) that will be used to instantiate the feature extractor object.
"""
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
@@ -488,13 +471,13 @@ def get_feature_extractor_dict(
user_agent=user_agent,
revision=revision,
)
- except EnvironmentError:
+ except OSError:
# Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to
# the original exception.
raise
except Exception:
# For any other exception, we throw a generic error.
- raise EnvironmentError(
+ raise OSError(
f"Can't load feature extractor for '{pretrained_model_name_or_path}'. If you were trying to load"
" it from 'https://huggingface.co/models', make sure you don't have a local directory with the"
f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
@@ -503,12 +486,12 @@ def get_feature_extractor_dict(
try:
# Load feature_extractor dict
- with open(resolved_feature_extractor_file, "r", encoding="utf-8") as reader:
+ with open(resolved_feature_extractor_file, encoding="utf-8") as reader:
text = reader.read()
feature_extractor_dict = json.loads(text)
except json.JSONDecodeError:
- raise EnvironmentError(
+ raise OSError(
f"It looks like the config file at '{resolved_feature_extractor_file}' is not a valid JSON file."
)
@@ -519,30 +502,20 @@ def get_feature_extractor_dict(
f"loading configuration file {feature_extractor_file} from cache at {resolved_feature_extractor_file}"
)
- if not is_local:
- if "auto_map" in feature_extractor_dict:
- feature_extractor_dict["auto_map"] = add_model_info_to_auto_map(
- feature_extractor_dict["auto_map"], pretrained_model_name_or_path
- )
- if "custom_pipelines" in feature_extractor_dict:
- feature_extractor_dict["custom_pipelines"] = add_model_info_to_custom_pipelines(
- feature_extractor_dict["custom_pipelines"], pretrained_model_name_or_path
- )
-
return feature_extractor_dict, kwargs
@classmethod
- def from_dict(cls, feature_extractor_dict: Dict[str, Any], **kwargs) -> PreTrainedFeatureExtractor:
+ def from_dict(cls, feature_extractor_dict: dict[str, Any], **kwargs) -> PreTrainedFeatureExtractor:
"""
Instantiates a type of [`~feature_extraction_utils.FeatureExtractionMixin`] from a Python dictionary of
parameters.
Args:
- feature_extractor_dict (`Dict[str, Any]`):
+ feature_extractor_dict (`dict[str, Any]`):
Dictionary that will be used to instantiate the feature extractor object. Such a dictionary can be
retrieved from a pretrained checkpoint by leveraging the
[`~feature_extraction_utils.FeatureExtractionMixin.to_dict`] method.
- kwargs (`Dict[str, Any]`):
+ kwargs (`dict[str, Any]`):
Additional parameters from which to initialize the feature extractor object.
Returns:
@@ -568,10 +541,10 @@ def from_dict(cls, feature_extractor_dict: Dict[str, Any], **kwargs) -> PreTrain
else:
return feature_extractor
- def to_dict(self) -> Dict[str, Any]:
+ def to_dict(self) -> dict[str, Any]:
"""
Serializes this instance to a Python dictionary. Returns:
- `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
+ `dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
"""
output = copy.deepcopy(self.__dict__)
output["feature_extractor_type"] = self.__class__.__name__
@@ -595,7 +568,7 @@ def from_json_file(cls, json_file: Union[str, os.PathLike]) -> PreTrainedFeature
A feature extractor of type [`~feature_extraction_utils.FeatureExtractionMixin`]: The feature_extractor
object instantiated from that JSON file.
"""
- with open(json_file, "r", encoding="utf-8") as reader:
+ with open(json_file, encoding="utf-8") as reader:
text = reader.read()
feature_extractor_dict = json.loads(text)
return cls(**feature_extractor_dict)
@@ -641,11 +614,6 @@ def register_for_auto_class(cls, auto_class="AutoFeatureExtractor"):
Register this class with a given auto class. This should only be used for custom feature extractors as the ones
in the library are already mapped with `AutoFeatureExtractor`.
-
-
- This API is experimental and may have some slight breaking changes in the next releases.
-
-
Args:
auto_class (`str` or `type`, *optional*, defaults to `"AutoFeatureExtractor"`):
diff --git a/mindone/transformers/generation/beam_search.py b/mindone/transformers/generation/beam_search.py
index 0cb3aad2ea..01a2c02932 100644
--- a/mindone/transformers/generation/beam_search.py
+++ b/mindone/transformers/generation/beam_search.py
@@ -18,7 +18,7 @@
from abc import ABC, abstractmethod
from collections import UserDict
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Optional, Union
import numpy as np
from transformers.generation.beam_constraints import Constraint, ConstraintListState
@@ -44,7 +44,7 @@
Beam indices indicating to which beam hypothesis the `next_tokens` correspond.
pad_token_id (`int`, *optional*):
The id of the *padding* token.
- eos_token_id (`Union[int, List[int]]`, *optional*):
+ eos_token_id (`Union[int, list[int]]`, *optional*):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
beam_indices (`ms.Tensor`, *optional*):
Beam indices indicating to which beam hypothesis each token correspond.
@@ -80,7 +80,7 @@
The beam indices indicating to which beam the `final_beam_tokens` shall be added.
pad_token_id (`int`, *optional*):
The id of the *padding* token.
- eos_token_id (`Union[int, List[int]]`, *optional*):
+ eos_token_id (`Union[int, list[int]]`, *optional*):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
Return:
@@ -106,7 +106,7 @@ def process(
next_tokens: ms.Tensor,
next_indices: ms.Tensor,
**kwargs,
- ) -> Tuple[ms.Tensor]:
+ ) -> tuple[ms.Tensor]:
raise NotImplementedError("This is an abstract method.")
@abstractmethod
@@ -154,7 +154,7 @@ class BeamSearchScorer(BeamScorer):
[`~transformers.BeamSearchScorer.finalize`].
num_beam_groups (`int`, *optional*, defaults to 1):
Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams.
- See [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.
+ See [this paper](https://huggingface.co/papers/1610.02424) for more details.
max_length (`int`, *optional*):
The maximum length of the sequence to be generated.
"""
@@ -215,11 +215,11 @@ def process(
next_tokens: ms.Tensor,
next_indices: ms.Tensor,
pad_token_id: Optional[Union[int, ms.Tensor]] = None,
- eos_token_id: Optional[Union[int, List[int], ms.Tensor]] = None,
+ eos_token_id: Optional[Union[int, list[int], ms.Tensor]] = None,
beam_indices: Optional[ms.Tensor] = None,
group_index: Optional[int] = 0,
decoder_prompt_len: Optional[int] = 0,
- ) -> Dict[str, ms.Tensor]:
+ ) -> dict[str, ms.Tensor]:
# add up to the length which the next_scores is calculated on (including decoder prompt)
cur_len = input_ids.shape[-1] + 1
batch_size = len(self._beam_hyps) // self.num_beam_groups
@@ -320,10 +320,10 @@ def finalize(
final_beam_indices: ms.Tensor,
max_length: int,
pad_token_id: Optional[Union[int, ms.Tensor]] = None,
- eos_token_id: Optional[Union[int, List[int], ms.Tensor]] = None,
+ eos_token_id: Optional[Union[int, list[int], ms.Tensor]] = None,
beam_indices: Optional[ms.Tensor] = None,
decoder_prompt_len: Optional[int] = 0,
- ) -> Tuple[ms.Tensor]:
+ ) -> tuple[ms.Tensor]:
batch_size = len(self._beam_hyps) // self.num_beam_groups
if eos_token_id is not None and not isinstance(eos_token_id, ms.Tensor):
@@ -421,7 +421,7 @@ class ConstrainedBeamSearchScorer(BeamScorer):
Batch Size of `input_ids` for which standard beam search decoding is run in parallel.
num_beams (`int`):
Number of beams for beam search.
- constraints (`List[Constraint]`):
+ constraints (`list[Constraint]`):
A list of positive constraints represented as `Constraint` objects that must be fulfilled in the generation
output. For more information, the documentation of [`Constraint`] should be read.
length_penalty (`float`, *optional*, defaults to 1.0):
@@ -449,7 +449,7 @@ def __init__(
self,
batch_size: int,
num_beams: int,
- constraints: List[Constraint],
+ constraints: list[Constraint],
length_penalty: Optional[float] = 1.0,
do_early_stopping: Optional[Union[bool, str]] = False,
num_beam_hyps_to_keep: Optional[int] = 1,
@@ -508,10 +508,10 @@ def process(
next_indices: ms.Tensor,
scores_for_all_vocab: ms.Tensor,
pad_token_id: Optional[Union[int, ms.Tensor]] = None,
- eos_token_id: Optional[Union[int, List[int], ms.Tensor]] = None,
+ eos_token_id: Optional[Union[int, list[int], ms.Tensor]] = None,
beam_indices: Optional[ms.Tensor] = None,
decoder_prompt_len: Optional[int] = 0,
- ) -> Tuple[ms.Tensor]:
+ ) -> tuple[ms.Tensor]:
r"""
Args:
input_ids (`ms.Tensor` of shape `(batch_size * num_beams, sequence_length)`):
@@ -531,7 +531,7 @@ def process(
The scores of all tokens in the vocabulary for each of the beam hypotheses.
pad_token_id (`int`, *optional*):
The id of the *padding* token.
- eos_token_id (`Union[int, List[int]]`, *optional*):
+ eos_token_id (`Union[int, list[int]]`, *optional*):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
beam_indices (`ms.Tensor`, *optional*):
Beam indices indicating to which beam hypothesis each token correspond.
@@ -807,10 +807,10 @@ def finalize(
final_beam_indices: ms.Tensor,
max_length: int,
pad_token_id: Optional[Union[int, ms.Tensor]] = None,
- eos_token_id: Optional[Union[int, List[int], ms.Tensor]] = None,
+ eos_token_id: Optional[Union[int, list[int], ms.Tensor]] = None,
beam_indices: Optional[ms.Tensor] = None,
decoder_prompt_len: Optional[int] = 0,
- ) -> Tuple[ms.Tensor]:
+ ) -> tuple[ms.Tensor]:
batch_size = len(self._beam_hyps)
if eos_token_id is not None and not isinstance(eos_token_id, ms.Tensor):
diff --git a/mindone/transformers/generation/candidate_generator.py b/mindone/transformers/generation/candidate_generator.py
index f5c554ca07..0f2a4d380c 100644
--- a/mindone/transformers/generation/candidate_generator.py
+++ b/mindone/transformers/generation/candidate_generator.py
@@ -24,10 +24,10 @@
from transformers import is_sklearn_available
import mindspore as ms
-from mindspore import mint
+from mindspore import mint, nn
from mindspore import numpy as mnp
-from ..cache_utils import DynamicCache
+from ..mindspore_utils import prune_linear_layer
if is_sklearn_available():
from sklearn.metrics import roc_curve
@@ -286,8 +286,8 @@ def _update_past_and_masks(self, input_ids: ms.Tensor, remove_from_pkv: int = 0,
has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None
if has_past_key_values:
new_cache_size = input_ids.shape[-1] - 1 - remove_from_pkv
- self.assistant_kwargs["past_key_values"] = _crop_past_key_values(
- self.assistant_model, self.assistant_kwargs["past_key_values"], new_cache_size - num_added_tokens
+ self.assistant_kwargs["past_key_values"] = self.assistant_kwargs["past_key_values"].crop(
+ new_cache_size - num_added_tokens
)
self.assistant_kwargs = _prepare_attention_mask(
self.assistant_kwargs, input_ids.shape[-1], self.assistant_model.config.is_encoder_decoder
@@ -502,7 +502,6 @@ def get_candidates(self, input_ids: ms.Tensor) -> tuple[ms.Tensor, Optional[ms.T
if max_new_tokens == 0:
return input_ids, None
- input_ids = input_ids
remove_from_pkv = 0
assistant_input_ids, remove_from_pkv = self._prepare_assistant_input_ids(input_ids)
@@ -604,12 +603,69 @@ def _process_assistant_outputs(
return new_target_ids
+class _PruneReindexingLMHead(nn.Cell):
+ """
+ A class to prune and reindex the language model head.
+
+ This class prunes the language model head to only include the specified token IDs and reindexes the logits
+ to map back to the original vocabulary.
+
+ Args:
+ original_lm_head (nn.Cell): The original language model head.
+ assistant_overlap_token_ids (ms.Tensor): The token IDs to keep in the pruned head.
+ """
+
+ def __init__(self, original_lm_head, assistant_overlap_token_ids):
+ super().__init__()
+ self.pruned_lm_head = prune_linear_layer(original_lm_head, assistant_overlap_token_ids).to(
+ original_lm_head.weight.dtype
+ )
+
+ def construct(self, hidden_states):
+ pruned_logits = self.pruned_lm_head(hidden_states)
+ return pruned_logits
+
+
+class _MapInputEmbedding(nn.Cell):
+ def __init__(self, original_embedding: mint.nn.Embedding, assistant_overlap_token_ids):
+ """
+ Wraps an existing embedding layer and remaps token IDs before lookup.
+
+ Args:
+ original_embedding (mint.nn.Embedding): Pre-trained or existing embedding layer.
+ assistant_overlap_token_ids (ms.Tensor): Tensor of assistant vocabulary token IDs used to remap
+ incoming token IDs before the embedding lookup.
+ """
+ super().__init__()
+ self.original_embedding = original_embedding
+ self.weight = original_embedding.weight
+ self.assistant_overlap_token_ids = assistant_overlap_token_ids
+ self.map = False
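+        # Remapping is skipped while `map` is False, i.e. on the first forward pass (or right after
+        # `unmap_input_ids`), when the incoming ids are already in the assistant vocabulary space.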
+
+ def construct(self, input_ids: ms.Tensor) -> ms.Tensor:
+ """
+ Args:
+ input_ids (ms.Tensor): Tensor of token IDs (batch_size, seq_len).
+
+ Returns:
+ ms.Tensor: Corresponding input embeddings.
+ """
+ if self.map:
+ # Get the last item from input_ids
+ my_input_ids = self.assistant_overlap_token_ids[input_ids[0, -1]].unsqueeze(0).unsqueeze(0)
+ else:
+ self.map = True
+ my_input_ids = input_ids
+
+ return self.original_embedding(my_input_ids)
+
+
class AssistantToTargetTranslator:
"""
Translates token ids and logits between assistant and target model vocabularies. This class is used to handle
vocabulary mismatches when using different tokenizers for the assistant and target models in speculative decoding,
as introduced in the paper "Lossless Speculative Decoding Algorithms for Heterogeneous Vocabularies"
- (https://www.arxiv.org/abs/2502.05202).
+ (https://huggingface.co/papers/2502.05202).
It maintains mappings between the two vocabularies and handles token/logit conversion.
Args:
@@ -617,8 +673,12 @@ class AssistantToTargetTranslator:
The tokenizer used by the target (main) model.
assistant_tokenizer (`PreTrainedTokenizerBase`):
The tokenizer used by the assistant model.
- target_vocab_size (`int`, *optional*):
+ target_vocab_size (`int`):
The size of the target model's vocabulary. If not provided, will be inferred from the target tokenizer.
+ assistant_model (`PreTrainedModel`, *optional*):
+ The assistant model to be used. Defaults to `None` for backward compatibility.
+ assistant_prune_lm_head (`bool`, *optional*, defaults to `False`):
+ Whether to prune the assistant model's language model head to match the target vocabulary. Only
+ applicable if `assistant_model` is provided.
"""
FILTER_VALUE: float = -float("Inf") # The value used to filter out unmapped tokens in the logits.
@@ -629,9 +689,11 @@ def __init__(
target_tokenizer: "PreTrainedTokenizerBase",
assistant_tokenizer: "PreTrainedTokenizerBase",
target_vocab_size: int, # required since target_vocab_size can be different from the length of target_tokenizer.get_vocab()
+ assistant_model: Optional["PreTrainedModel"] = None,
+ assistant_prune_lm_head: bool = False,
):
- self._target_tokenizer: "PreTrainedTokenizerBase" = target_tokenizer
- self._assistant_tokenizer: "PreTrainedTokenizerBase" = assistant_tokenizer
+ self._target_tokenizer: PreTrainedTokenizerBase = target_tokenizer
+ self._assistant_tokenizer: PreTrainedTokenizerBase = assistant_tokenizer
self.target_vocab_size: int = target_vocab_size
(
self._assistant_to_target_input_ids,
@@ -639,11 +701,39 @@ def __init__(
) = self._get_assistant_to_target_input_ids()
self._suppress_input_ids: list[int] = self._get_suppress_input_ids()
self.logits_processors: Optional[LogitsProcessorList] = None
+ self.assistant_prune_lm_head = assistant_prune_lm_head and assistant_model is not None
if len(self._suppress_input_ids) > 0:
- # len(self._suppress_input_ids) = 0 if the assistant vocab is a subset of the target vocab
- self.logits_processors = LogitsProcessorList(
- [SuppressTokensLogitsProcessor(self._get_suppress_input_ids())]
- )
+ # the assistant vocab is not a subset of the target vocab
+ if self.assistant_prune_lm_head:
+ self.assistant_overlap_token_ids = ms.tensor(
+ list(self.target_to_assistant_input_ids.values()),
+ dtype=ms.int64,
+ )
+ original_lm_head = assistant_model.get_output_embeddings()
+ pruned_lm_head = _PruneReindexingLMHead(original_lm_head, self.assistant_overlap_token_ids)
+ del original_lm_head
+ assistant_model.set_output_embeddings(pruned_lm_head)
+
+ original_input_embeddings = assistant_model.get_input_embeddings()
+ map_input_embeddings = _MapInputEmbedding(original_input_embeddings, self.assistant_overlap_token_ids)
+ del original_input_embeddings
+ assistant_model.set_input_embeddings(map_input_embeddings)
+ self.map_input_embeddings = map_input_embeddings
+ else:
+ self.logits_processors = LogitsProcessorList(
+ [SuppressTokensLogitsProcessor(self._get_suppress_input_ids())]
+ )
+
+ def unmap_input_ids(self):
+ """
+ Disables the mapping of input ids even when assistant LM head pruning is enabled.
+
+ This method is required for the first forward pass of `_MapInputEmbedding` where input ids are already in the assistant vocabulary space.
+ By disabling the mapping, it ensures that the input ids are processed correctly without remapping.
+
+ """
+ if self.assistant_prune_lm_head:
+ self.map_input_embeddings.map = False
def _get_assistant_to_target_input_ids(self):
target_vocab = self._target_tokenizer.get_vocab()
@@ -671,14 +761,14 @@ def _get_assistant_to_target_input_ids(self):
}
max_assistant_index = max(assistant_vocab.values())
- assistant_to_target_input_ids = mint.full((max_assistant_index + 1,), self.SUPPRESS_TOKEN_ID, dtype=int)
+ assistant_to_target_input_ids = mint.full((max_assistant_index + 1,), self.SUPPRESS_TOKEN_ID, dtype=ms.int32)
target_to_assistant_input_ids: dict[int, int] = {}
for tok, assistant_id in assistant_vocab.items():
target_id = target_vocab.get(tok)
if target_id is not None:
assistant_to_target_input_ids[assistant_id] = target_id
target_to_assistant_input_ids[target_id] = assistant_id
- return assistant_to_target_input_ids.to(self._assistant_model_device), target_to_assistant_input_ids
+ return assistant_to_target_input_ids, target_to_assistant_input_ids
def _get_suppress_input_ids(self) -> list[int]:
"""
@@ -697,7 +787,12 @@ def get_target_ids(self, assistant_input_ids, target_input_ids, assistant_candid
if num_new_tokens == 0:
return target_input_ids
else:
- transformed_slice = self._assistant_to_target_input_ids[assistant_candidate_ids[0, -num_new_tokens:]]
+ # Get last `num_new_tokens` candidate IDs
+ last_candidate_ids = assistant_candidate_ids[0, -num_new_tokens:]
+ if self.assistant_prune_lm_head:
+ # Map assistant IDs -> target input IDs
+ last_candidate_ids = self.assistant_overlap_token_ids[last_candidate_ids]
+ transformed_slice = self._assistant_to_target_input_ids[last_candidate_ids]
return mint.cat((target_input_ids, transformed_slice.unsqueeze(0)), dim=1)
def get_target_logits(self, assistant_logits: ms.Tensor) -> ms.Tensor:
@@ -713,8 +808,11 @@ def get_target_logits(self, assistant_logits: ms.Tensor) -> ms.Tensor:
target_logits_supported_indices = self._assistant_to_target_input_ids[assistant_indices_mask]
valid_assistant_logits = assistant_logits[..., : self._assistant_to_target_input_ids.shape[0]]
- target_logits[..., target_logits_supported_indices] = valid_assistant_logits[..., assistant_indices_mask]
-
+ if self.assistant_prune_lm_head:
+ target_logits[..., target_logits_supported_indices] = assistant_logits
+ else:
+ valid_assistant_logits = assistant_logits[..., : self._assistant_to_target_input_ids.shape[0]]
+ target_logits[..., target_logits_supported_indices] = valid_assistant_logits[..., assistant_indices_mask]
return target_logits
@@ -732,7 +830,8 @@ def get_translator(
target_tokenizer: "PreTrainedTokenizerBase",
assistant_tokenizer: "PreTrainedTokenizerBase",
target_vocab_size: int,
- assistant_model_device: str = "cpu",
+ assistant_model: Optional["PreTrainedModel"] = None,
+ assistant_prune_lm_head: bool = False,
) -> AssistantToTargetTranslator:
assistant_dict = cls._cache.get(target_tokenizer)
if assistant_dict is None:
@@ -742,7 +841,7 @@ def get_translator(
mapping = assistant_dict.get(assistant_tokenizer)
if mapping is None:
mapping = AssistantToTargetTranslator(
- target_tokenizer, assistant_tokenizer, target_vocab_size, assistant_model_device
+ target_tokenizer, assistant_tokenizer, target_vocab_size, assistant_model, assistant_prune_lm_head
)
assistant_dict[assistant_tokenizer] = mapping
@@ -879,7 +978,7 @@ def _prepare_assistant_input_ids(self, target_input_ids: ms.Tensor) -> ms.Tensor
self._prev_assistant_ids = self._prev_assistant_ids[:, :-tokens_to_remove]
assistant_input_ids = mint.cat([self._prev_assistant_ids, assistant_new_ids], dim=-1)
assistant_input_ids = assistant_input_ids.to(dtype=ms.int64)
-
+ self._atm_translator.unmap_input_ids()
return assistant_input_ids, len(assistant_new_ids[0])
@@ -925,7 +1024,7 @@ def get_candidates(self, input_ids: ms.Tensor) -> tuple[ms.Tensor, Optional[ms.T
Return:
`ms.Tensor` of shape `(num_candidates, candidate_length)`: The candidate sequences to be tried.
"""
- input_length = input_ids.size(1)
+ input_length = input_ids.shape[1]
# Don't generate more than `max_length - 1` candidates since the target model generates one extra token.
if self.max_length == input_length + 1:
@@ -1051,47 +1150,6 @@ def get_candidates(self, input_ids: ms.Tensor) -> tuple[ms.Tensor, Optional[ms.T
return candidate_ids, candidate_logits
-def _crop_past_key_values(model, past_key_values, max_length):
- """Crops the past key values up to a certain maximum length."""
- new_past = []
- if model.config.is_encoder_decoder:
- for idx in range(len(past_key_values)):
- new_past.append(
- (
- past_key_values[idx][0][:, :, :max_length, :],
- past_key_values[idx][1][:, :, :max_length, :],
- past_key_values[idx][2],
- past_key_values[idx][3],
- )
- )
- past_key_values = tuple(new_past)
- # gptbigcode is special and stores kv in shape (batch_size, seq_len, dim), if it's a multi_query model
- elif "gptbigcode" in model.__class__.__name__.lower() or (
- model.config.architectures is not None and "gptbigcode" in model.config.architectures[0].lower()
- ):
- if model.config.multi_query:
- for idx in range(len(past_key_values)):
- past_key_values[idx] = past_key_values[idx][:, :max_length, :]
- else:
- for idx in range(len(past_key_values)):
- past_key_values[idx] = past_key_values[idx][:, :, :max_length, :]
- elif isinstance(past_key_values, DynamicCache):
- past_key_values.crop(max_length)
- elif past_key_values is not None:
- for idx in range(len(past_key_values)):
- if past_key_values[idx] != ([], []):
- new_past.append(
- (
- past_key_values[idx][0][:, :, :max_length, :],
- past_key_values[idx][1][:, :, :max_length, :],
- )
- )
- else:
- new_past.append((past_key_values[idx][0], past_key_values[idx][1]))
- past_key_values = tuple(new_past)
- return past_key_values
-
-
def _prepare_attention_mask(model_kwargs: dict[str, Any], new_length: int, is_encoder_decoder: bool) -> dict[str, Any]:
"""Expands or crops the model's mask for decoding purposes, to the defined length"""
diff --git a/mindone/transformers/generation/logits_process.py b/mindone/transformers/generation/logits_process.py
index 66e002d7ff..036cb6238d 100644
--- a/mindone/transformers/generation/logits_process.py
+++ b/mindone/transformers/generation/logits_process.py
@@ -18,7 +18,7 @@
import inspect
import math
-from typing import Callable, Iterable, List, Optional, Union
+from typing import Callable, Iterable, Optional, Union
import numpy as np
from transformers.utils import add_start_docstrings
@@ -62,18 +62,6 @@ def __call__(
)
-class LogitsWarper:
- """Abstract base class for all logit warpers that can be applied during generation with multinomial sampling."""
-
- @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
- def __call__(
- self, input_ids: Union[ms.Tensor, np.ndarray], scores: Union[ms.Tensor, np.ndarray]
- ) -> Union[ms.Tensor, np.ndarray]:
- raise NotImplementedError(
- f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
- )
-
-
class LogitsProcessorList(list):
"""
This class can be used to create a list of [`LogitsProcessor`] to subsequently process a `scores` input tensor.
@@ -91,7 +79,7 @@ def __call__(
scores (`Union[ms.Tensor, np.ndarray]` of shape `(batch_size, config.vocab_size)`):
Prediction scores of a language modeling head. These can be logits for each vocabulary when not using
beam search or log softmax for each vocabulary token when using beam search
- kwargs (`Dict[str, Any]`, *optional*):
+ kwargs (`dict[str, Any]`, *optional*):
Additional kwargs that are specific to a logits processor.
Return:
@@ -122,11 +110,42 @@ class MinLengthLogitsProcessor(LogitsProcessor):
Args:
min_length (`int`):
The minimum length below which the score of `eos_token_id` is set to `-float("Inf")`.
- eos_token_id (`Union[int, List[int], ms.Tensor, np.ndarray]`):
+ eos_token_id (`Union[int, list[int], ms.Tensor, np.ndarray]`):
The id(s) of the *end-of-sequence* token.
+
+ Examples:
+
+ ```python
+ >>> import mindspore as ms
+ >>> from transformers import AutoTokenizer
+ >>> from mindone.transformers import AutoModelForCausalLM
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("bigscience/bloomz-560m")
+ >>> model = AutoModelForCausalLM.from_pretrained("bigscience/bloomz-560m")
+
+ >>> inputs = tokenizer("A number:", return_tensors="np")
+
+ >>> for key in inputs.keys():
+ >>> inputs[key] = ms.tensor(inputs[key])
+
+ >>> gen_out = model.generate(**inputs)
+ >>> print(tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0])
+ A number: one
+
+ >>> # setting `min_length` to a value smaller than the uncontrolled output length has no impact
+ >>> gen_out = model.generate(**inputs, min_length=3)
+ >>> print(tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0])
+ A number: one
+
+ >>> # setting a larger `min_length` will force the model to generate beyond its natural ending point, which is not
+ >>> # necessarily incorrect
+ >>> gen_out = model.generate(**inputs, min_length=10)
+ >>> print(tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0])
+ A number: one thousand, nine hundred and ninety-four
+ ```
"""
- def __init__(self, min_length: int, eos_token_id: Union[int, List[int], ms.Tensor, np.ndarray], **ignore):
+ def __init__(self, min_length: int, eos_token_id: Union[int, list[int], ms.Tensor, np.ndarray], **ignore):
if not isinstance(min_length, int) or min_length < 0:
raise ValueError(f"`min_length` has to be a non-negative integer, but is {min_length}")
@@ -175,12 +194,36 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
input length.
min_new_tokens (`int`):
The minimum *new* tokens length below which the score of `eos_token_id` is set to `-float("Inf")`.
- eos_token_id (`Union[int, List[int], ms.Tensor]`):
+ eos_token_id (`Union[int, list[int], ms.Tensor]`):
The id(s) of the *end-of-sequence* token.
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoTokenizer
+ >>> from mindone.transformers import AutoModelForCausalLM
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("bigscience/bloomz-560m")
+ >>> model = AutoModelForCausalLM.from_pretrained("bigscience/bloomz-560m")
+
+ >>> inputs = tokenizer(["A number:"], return_tensors="np")
+ >>> for key in inputs.keys():
+ >>> inputs[key] = ms.tensor(inputs[key])
+
+ >>> gen_out = model.generate(**inputs)
+ >>> print(tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0])
+ A number: one
+
+ >>> # setting `min_new_tokens` will force the model to generate beyond its natural ending point, which is not
+ >>> # necessarily incorrect
+ >>> gen_out = model.generate(**inputs, min_new_tokens=2)
+ >>> print(tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0])
+ A number: one thousand
+ ```
"""
def __init__(
- self, prompt_length_to_skip: int, min_new_tokens: int, eos_token_id: Union[int, List[int], ms.Tensor], **ignore
+ self, prompt_length_to_skip: int, min_new_tokens: int, eos_token_id: Union[int, list[int], ms.Tensor], **ignore
):
for arg_name, arg_value in [
("prompt_length_to_skip", prompt_length_to_skip),
@@ -244,6 +287,38 @@ class TemperatureLogitsWarper(LogitsProcessor):
Strictly positive float value used to modulate the logits distribution. A value smaller than `1` decreases
randomness (and vice versa), with `0` being equivalent to shifting all probability mass to the most likely
token.
+
+ Examples:
+
+ ```python
+ >>> import mindspore as ms
+ >>> from transformers import AutoTokenizer, set_seed
+ >>> from mindone.transformers import AutoModelForCausalLM
+
+ >>> set_seed(0) # for reproducibility
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
+ >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
+ >>> model.config.pad_token_id = model.config.eos_token_id
+ >>> inputs = tokenizer(["Hugging Face Company is"], return_tensors="np")
+ >>> for key in inputs.keys():
+ >>> inputs[key] = ms.tensor(inputs[key])
+
+ >>> # With temperature=1.0, the default, we consistently get random outputs due to random sampling.
+ >>> generate_kwargs = {"max_new_tokens": 10, "do_sample": True, "temperature": 1.0, "num_return_sequences": 2}
+ >>> outputs = model.generate(**inputs, **generate_kwargs)
+ >>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
+ ['Hugging Face Company is one of these companies that is going to take a',
+ "Hugging Face Company is a brand created by Brian A. O'Neil"]
+
+ >>> # However, with temperature close to 0, it approximates greedy decoding strategies (invariant)
+ >>> generate_kwargs["temperature"] = 0.0001
+ >>> outputs = model.generate(**inputs, **generate_kwargs)
+ >>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
+ ['Hugging Face Company is a company that has been around for over 20 years',
+ 'Hugging Face Company is a company that has been around for over 20 years']
+ ```
+
"""
def __init__(self, temperature: float):
@@ -269,9 +344,10 @@ def __call__(
class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
r"""
[`LogitsProcessor`] that prevents the repetition of previous tokens through a penalty. This penalty is applied at
- most once per token. Note that, for decoder-only models like most LLMs, the considered tokens include the prompt.
+ most once per token. Note that, for decoder-only models like most LLMs, the considered tokens include the prompt
+ by default.
- In the original [paper](https://arxiv.org/pdf/1909.05858.pdf), the authors suggest the use of a penalty of around
+ In the original [paper](https://huggingface.co/papers/1909.05858), the authors suggest the use of a penalty of around
1.2 to achieve a good balance between truthful generation and lack of repetition. To penalize and reduce
repetition, use `penalty` values above 1.0, where a higher value penalizes more strongly. To reward and encourage
repetition, use `penalty` values between 0.0 and 1.0, where a lower value rewards more strongly.
@@ -280,21 +356,102 @@ class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
penalty (`float`):
The parameter for repetition penalty. 1.0 means no penalty. Above 1.0 penalizes previously generated
tokens. Between 0.0 and 1.0 rewards previously generated tokens.
+ prompt_ignore_length (`int`, *optional*):
+ The original input ids sequence length, which if provided, will not be used in the penalty calculation.
+ Examples:
+
+ ```py
+ >>> import mindspore as ms
+ >>> from transformers import AutoTokenizer
+ >>> from mindone.transformers import AutoModelForCausalLM, RepetitionPenaltyLogitsProcessor
+
+ >>> # Initializing the model and tokenizer for it
+ >>> model = AutoModelForCausalLM.from_pretrained("distilbert/distilgpt2")
+ >>> tokenizer = AutoTokenizer.from_pretrained("distilbert/distilgpt2")
+ >>> inputs = tokenizer(["I'm not going to"], return_tensors="np")
+ >>> for key in inputs.keys():
+ >>> inputs[key] = ms.tensor(inputs[key])
+
+ >>> # This shows a normal generate without any specific parameters
+ >>> summary_ids = model.generate(**inputs)
+ >>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True)[0])
+ I'm not going to be able to do that. I'm going to be able to do that
+
+ >>> # This generates a penalty for repeated tokens
+ >>> penalized_ids = model.generate(**inputs, repetition_penalty=1.1)
+ >>> print(tokenizer.batch_decode(penalized_ids, skip_special_tokens=True)[0])
+ I'm not going to be able to do that. I'll just have to go out and play
+
+ >>> # We can also exclude the input prompt by creating an instance of this class
+ >>> # with a `prompt_ignore_length` and passing it as a custom logit processor
+ >>> rep_pen_processor = RepetitionPenaltyLogitsProcessor(
+ ... penalty=1.1,
+ ... prompt_ignore_length=inputs["input_ids"].shape[-1]
+ ... )
+ >>> penalized_ids = model.generate(**inputs, logits_processor=[rep_pen_processor])
+ >>> print(tokenizer.batch_decode(penalized_ids, skip_special_tokens=True)[0])
+ I'm not going to be able to do that. I'm going to have to go through a lot of things, and
+ ```
+
"""
- def __init__(self, penalty: float):
+ def __init__(self, penalty: float, prompt_ignore_length: Optional[int] = None):
if not isinstance(penalty, float) or not (penalty > 0):
raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")
+ if prompt_ignore_length is not None and (not isinstance(prompt_ignore_length, int) or prompt_ignore_length < 0):
+ raise ValueError(f"`prompt_ignore_length` has to be a positive integer, but is {prompt_ignore_length}")
+
self.penalty = penalty
+ self.prompt_ignore_length = prompt_ignore_length
+ self.logits_indices = None
+ self.cumulative_seqlens_q = None
+
+ def set_continuous_batching_context(self, logits_indices: ms.Tensor, cumulative_seqlens_q: ms.Tensor):
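+        """
+        Registers continuous-batching metadata: `logits_indices` marks the last position of each sequence in the
+        flattened batch and `cumulative_seqlens_q` holds cumulative query lengths. Both are used by `__call__`
+        when `scores` is 3-D so the penalty can be applied per sequence.
+        """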
+ self.logits_indices = logits_indices
+ self.cumulative_seqlens_q = cumulative_seqlens_q
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: ms.Tensor, scores: ms.Tensor) -> ms.Tensor:
- score = mint.gather(scores, 1, input_ids)
+ if self.prompt_ignore_length:
+ input_ids = input_ids[:, self.prompt_ignore_length :]
+
+ if scores.dim() == 3:
+ if self.logits_indices is not None and self.cumulative_seqlens_q is not None:
+ batch_size, seq_len, vocab_size = scores.shape
+ last_positions = self.logits_indices
+ last_scores = scores[0, last_positions, :]
+
+ # Prepare token mask
+ token_mask = mint.zeros_like(last_scores, dtype=ms.bool_)
+ cu_seq_lens = self.cumulative_seqlens_q
+ lengths = cu_seq_lens[1:] - cu_seq_lens[:-1]
+ seq_indices = mint.repeat_interleave(mint.arange(len(lengths)), lengths)
+ token_mask[seq_indices, input_ids] = True
+
+ # Apply penalty
+ penalty_scores = mint.where(last_scores < 0, last_scores * self.penalty, last_scores / self.penalty)
+ scores[0, last_positions, :] = mint.where(token_mask, penalty_scores, last_scores)
+ else:
+ batch_size, seq_len, vocab_size = scores.shape
+ last_scores = scores[:, -1, :]
+ token_mask = mint.zeros_like(last_scores, dtype=ms.bool_)
+ if input_ids.dim() == 1:
+ unique_tokens = mint.unique(input_ids)
+ token_mask.scatter_(1, unique_tokens.unsqueeze(0), True)
+ else:
+ token_mask.scatter_(1, input_ids, True)
+ # if last_scores < 0 then repetition penalty has to be multiplied to reduce the token probabilities
+ penalty_scores = mint.where(last_scores < 0, last_scores * self.penalty, last_scores / self.penalty)
+ scores[:, -1, :] = mint.where(token_mask, penalty_scores, last_scores)
+ return scores
+ if input_ids.dim() == 1:
+ input_ids = input_ids.unsqueeze(1)
+
+ score = mint.gather(scores, 1, input_ids)
# if score < 0 then repetition penalty has to be multiplied to reduce the token probabilities
score = mint.where(score < 0, score * self.penalty, score / self.penalty)
-
scores_processed = scores.scatter(1, input_ids, score)
return scores_processed
@@ -354,7 +511,7 @@ def __call__(self, input_ids: ms.Tensor, scores: ms.Tensor) -> ms.Tensor:
return scores_processed
-class TopPLogitsWarper(LogitsWarper):
+class TopPLogitsWarper(LogitsProcessor):
"""
[`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off. Often
used together with [`TemperatureLogitsWarper`] and [`TopKLogitsWarper`].
@@ -449,6 +606,33 @@ class TopKLogitsWarper(LogitsProcessor):
All filtered values will be set to this float value.
min_tokens_to_keep (`int`, *optional*, defaults to 1):
Minimum number of tokens that cannot be filtered.
+
+ Examples:
+
+ ```python
+ >>> import mindspore as ms
+ >>> from transformers import AutoTokenizer, set_seed
+ >>> from mindone.transformers import AutoModelForCausalLM
+
+ >>> set_seed(1)
+ >>> model = AutoModelForCausalLM.from_pretrained("distilbert/distilgpt2")
+ >>> tokenizer = AutoTokenizer.from_pretrained("distilbert/distilgpt2")
+
+ >>> inputs = tokenizer("A sequence: A, B, C, D", return_tensors="np")
+ >>> for key in inputs.keys():
+ ...     inputs[key] = ms.tensor(inputs[key])
+
+ >>> # With sampling, the output is unexpected -- sometimes too unexpected.
+ >>> outputs = model.generate(**inputs, do_sample=True)
+ >>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
+ A sequence: A, B, C, D, E — S — O, P — R
+
+ >>> # With `top_k` sampling, the output gets restricted to the k most likely tokens.
+ >>> # Pro tip: In practice, LLMs use `top_k` in the 5-50 range.
+ >>> outputs = model.generate(**inputs, do_sample=True, top_k=2)
+ >>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
+ A sequence: A, B, C, D, E, F, G, H, I
+ ```
"""
def __init__(self, top_k: int, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
@@ -477,7 +661,7 @@ def __call__(
class MinPLogitsWarper(LogitsProcessor):
"""
[`LogitsProcessor`] that performs min-p, i.e. keeps all tokens that are above a minimum probability, scaled by the
- probability of the most likely token. As a result, the filter becomes more agressive in the presence of
+ probability of the most likely token. As a result, the filter becomes more aggressive in the presence of
high-probability tokens, which is a sign of a confident output that we shouldn't deviate from.
Often used together with [`TemperatureLogitsWarper`]. Used as an alternative to [`TopPLogitsWarper`] and
@@ -499,13 +683,17 @@ class MinPLogitsWarper(LogitsProcessor):
Examples:
```python
- >>> from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
+ >>> import mindspore as ms
+ >>> from transformers import AutoTokenizer, set_seed
+ >>> from mindone.transformers import AutoModelForCausalLM
>>> set_seed(1)
>>> model = AutoModelForCausalLM.from_pretrained("distilbert/distilgpt2")
>>> tokenizer = AutoTokenizer.from_pretrained("distilbert/distilgpt2")
- >>> inputs = tokenizer("A sequence: 1, 2", return_tensors="pt")
+ >>> inputs = tokenizer("A sequence: 1, 2", return_tensors="np")
+ >>> for key in inputs.keys():
+ ...     inputs[key] = ms.tensor(inputs[key])
>>> # With sampling, the output is unexpected -- sometimes too unexpected.
>>> outputs = model.generate(**inputs, do_sample=True)
@@ -558,7 +746,7 @@ class TypicalLogitsWarper(LogitsProcessor):
whose log probability is close to the entropy of the token probability distribution. This means that the most
likely tokens may be discarded in the process.
- See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information.
+ See [Typical Decoding for Natural Language Generation](https://huggingface.co/papers/2202.00666) for more information.
Args:
mass (`float`, *optional*, defaults to 0.9):
@@ -571,12 +759,16 @@ class TypicalLogitsWarper(LogitsProcessor):
Examples:
```python
- >>> from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
+ >>> import mindspore as ms
+ >>> from transformers import AutoTokenizer, set_seed
+ >>> from mindone.transformers import AutoModelForCausalLM
>>> model = AutoModelForCausalLM.from_pretrained("bigscience/bloomz-560m")
>>> tokenizer = AutoTokenizer.from_pretrained("bigscience/bloomz-560m")
- >>> inputs = tokenizer("1, 2, 3", return_tensors="pt")
+ >>> inputs = tokenizer("1, 2, 3", return_tensors="np")
+ >>> for key in inputs.keys():
+ ...     inputs[key] = ms.tensor(inputs[key])
>>> # We can see that greedy decoding produces a sequence of numbers
>>> outputs = model.generate(**inputs)
@@ -644,7 +836,7 @@ class EpsilonLogitsWarper(LogitsProcessor):
r"""
[`LogitsProcessor`] that performs epsilon-sampling, i.e. restricting to tokens with `prob >= epsilon`. Takes the
largest min_tokens_to_keep tokens if no tokens satisfy this constraint. See [Truncation Sampling as Language Model
- Desmoothing](https://arxiv.org/abs/2210.15191) for more information.
+ Desmoothing](https://huggingface.co/papers/2210.15191) for more information.
Args:
epsilon (`float`):
@@ -656,13 +848,17 @@ class EpsilonLogitsWarper(LogitsProcessor):
Examples:
```python
- >>> from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
+ >>> import mindspore as ms
+ >>> from transformers import AutoTokenizer, set_seed
+ >>> from mindone.transformers import AutoModelForCausalLM
>>> set_seed(1)
>>> model = AutoModelForCausalLM.from_pretrained("distilbert/distilgpt2")
>>> tokenizer = AutoTokenizer.from_pretrained("distilbert/distilgpt2")
- >>> inputs = tokenizer("A sequence: 1, 2", return_tensors="pt")
+ >>> inputs = tokenizer("A sequence: 1, 2", return_tensors="np")
+ >>> for key in inputs.keys():
+ ...     inputs[key] = ms.tensor(inputs[key])
>>> # With sampling, the output is unexpected -- sometimes too unexpected.
>>> outputs = model.generate(**inputs, do_sample=True)
@@ -673,7 +869,7 @@ class EpsilonLogitsWarper(LogitsProcessor):
>>> # With epsilon sampling, the output gets restricted to high-probability tokens. Note that this is similar to
>>> # Top P sampling, which restricts tokens based on their cumulative probability.
- >>> # Pro tip: The paper recomends using `epsilon_cutoff` values between 3e-4 and 9e-4
+ >>> # Pro tip: The paper recommends using `epsilon_cutoff` values between 3e-4 and 9e-4
>>> outputs = model.generate(**inputs, do_sample=True, epsilon_cutoff=0.1)
>>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
A sequence: 1, 2, 3, 4, 5, 6, 7, 8, 9
@@ -714,7 +910,7 @@ class EtaLogitsWarper(LogitsProcessor):
the token probabilities, i.e. `eta := min(epsilon, sqrt(epsilon * e^-entropy(probabilities)))`. Takes the largest
min_tokens_to_keep tokens if no tokens satisfy this constraint. It addresses the issue of poor quality in long
samples of text generated by neural language models leading to more coherent and fluent text. See [Truncation
- Sampling as Language Model Desmoothing](https://arxiv.org/abs/2210.15191) for more information. Note: `do_sample`
+ Sampling as Language Model Desmoothing](https://huggingface.co/papers/2210.15191) for more information. Note: `do_sample`
must be set to `True` for this `LogitsProcessor` to work.
@@ -840,7 +1036,7 @@ def _get_generated_ngrams(banned_ngrams, prev_input_ids, ngram_size, cur_len):
def _calc_banned_ngram_tokens(
ngram_size: int, prev_input_ids: ms.Tensor, num_hypos: int, cur_len: int
-) -> List[Iterable[int]]:
+) -> list[Iterable[int]]:
"""Copied from fairseq for no_repeat_ngram in beam_search"""
if cur_len + 1 < ngram_size:
# return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
@@ -867,7 +1063,7 @@ class NoRepeatNGramLogitsProcessor(LogitsProcessor):
Use n-gram penalties with care. For instance, penalizing 2-grams (bigrams) in an article about the city of New York
might lead to undesirable outcomes where the city's name appears only once in the entire text.
- [Reference](https://huggingface.co/blog/how-to-generate)
+ [Reference](https://huggingface.co/blog/how-to-generate)
@@ -990,12 +1186,12 @@ class SequenceBiasLogitsProcessor(LogitsProcessor):
At a token-level, biasing a word is different from biasing a word with a space before it. If you want to bias
"foo" mid-sentence, you'll likely want to add a prefix space and bias " foo" instead. Check the tokenizer section
- of our NLP course to find out why: https://huggingface.co/learn/nlp-course/chapter2/4?fw=pt
+ of our NLP course to find out why: https://huggingface.co/learn/nlp-course/chapter2/4?fw=pt
Args:
- sequence_bias (`List[List[Union[List[int], float]]]`):
+ sequence_bias (`list[list[Union[list[int], float]]]`):
List of lists that maps a sequence of tokens to its bias term (e.g. `[[[10, 45], -2.0],
[[64], -7.5]]`). Positive biases increase the odds of the
sequence being selected, while negative biases do the opposite. If a sequence has a length of 1, its bias
@@ -1005,7 +1201,8 @@ class SequenceBiasLogitsProcessor(LogitsProcessor):
Examples:
```python
- >>> from transformers import AutoTokenizer, AutoModelForCausalLM
+ >>> from transformers import AutoTokenizer
+ >>> from mindone.transformers import AutoModelForCausalLM
>>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
@@ -1044,13 +1241,13 @@ class SequenceBiasLogitsProcessor(LogitsProcessor):
```
"""
- def __init__(self, sequence_bias: List[List[Union[List[int], float]]]):
+ def __init__(self, sequence_bias: list[list[Union[list[int], float]]]):
self.sequence_bias = sequence_bias
self._validate_arguments()
self._convert_list_arguments_into_dict()
# Bias variables that will be populated on the first call (for retrocompatibility purposes, the vocabulary size
- # is infered in the first usage, which inhibits initializing here)
+ # is inferred in the first usage, which inhibits initializing here)
self.length_1_bias = None
self.prepared_bias_variables = False
@@ -1106,9 +1303,13 @@ def _prepare_bias_variables(self, scores: ms.Tensor):
# Precompute the bias tensors to be applied. Sequences of length 1 are kept separately, as they can be applied
# with simpler logic.
self.length_1_bias = mint.zeros((vocabulary_size,), dtype=ms.float32)
+ # Extract single-token sequences and their biases
+ single_token_ids = []
+ single_token_biases = []
for sequence_ids, bias in self.sequence_bias.items():
if len(sequence_ids) == 1:
- self.length_1_bias[sequence_ids[-1]] = bias
+ single_token_ids.append(sequence_ids[0])
+ single_token_biases.append(bias)
+
+ # Write the collected single-token biases into `length_1_bias` with one vectorized assignment
+ if single_token_ids:
+ self.length_1_bias[ms.tensor(single_token_ids)] = ms.tensor(single_token_biases, dtype=ms.float32)
self.prepared_bias_variables = True
@@ -1166,14 +1367,14 @@ class NoBadWordsLogitsProcessor(SequenceBiasLogitsProcessor):
`add_prefix_space=True` when initializing the tokenizer, and use `tokenizer(bad_words,
add_special_tokens=False).input_ids`. The `add_prefix_space` argument is only supported for some slow tokenizers,
as fast tokenizers' prefixing behaviours come from `pre tokenizers`. Read more
- [here](https://huggingface.co/docs/tokenizers/api/pre-tokenizers).
+ [here](https://huggingface.co/docs/tokenizers/api/pre-tokenizers).
Args:
- bad_words_ids (`List[List[int]]`):
+ bad_words_ids (`list[list[int]]`):
List of list of token ids that are not allowed to be generated.
- eos_token_id (`Union[int, List[int], ms.Tensor]`, *optional*):
+ eos_token_id (`Union[int, list[int], ms.Tensor]`, *optional*):
The id(s) of the *end-of-sequence* token.
Examples:
@@ -1211,7 +1412,7 @@ class NoBadWordsLogitsProcessor(SequenceBiasLogitsProcessor):
```
"""
- def __init__(self, bad_words_ids: List[List[int]], eos_token_id: Optional[Union[int, List[int], ms.Tensor]] = None):
+ def __init__(self, bad_words_ids: list[list[int]], eos_token_id: Optional[Union[int, list[int], ms.Tensor]] = None):
self.bad_word_ids = bad_words_ids
self._validate_arguments()
@@ -1222,8 +1423,9 @@ def __init__(self, bad_words_ids: List[List[int]], eos_token_id: Optional[Union[
eos_token_id = [eos_token_id]
eos_token_id = ms.tensor(eos_token_id)
+ eos_token_id_list = eos_token_id.tolist()  # convert to a Python list before filtering below
bad_words_ids = list(
- filter(lambda bad_token_seq: all(bad_token_seq != [i] for i in eos_token_id), bad_words_ids)
+ filter(lambda bad_token_seq: all(bad_token_seq != [i] for i in eos_token_id_list), bad_words_ids)
)
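+ # Sketch of the filtering above (hypothetical values): with bad_words_ids = [[5], [6, 7]] and
+ # eos_token_id = 5, the single-token entry [5] equals an EOS id and is dropped, leaving [[6, 7]].
+ # Comparing against the plain Python list keeps `bad_token_seq != [i]` a list comparison instead
+ # of an elementwise tensor comparison.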
# Forbidding a sequence is equivalent to setting its bias to -inf
@@ -1248,17 +1450,55 @@ def _validate_arguments(self):
class PrefixConstrainedLogitsProcessor(LogitsProcessor):
r"""
[`LogitsProcessor`] that enforces constrained generation and is useful for prefix-conditioned constrained
- generation. See [Autoregressive Entity Retrieval](https://arxiv.org/abs/2010.00904) for more information.
+ generation. See [Autoregressive Entity Retrieval](https://huggingface.co/papers/2010.00904) for more information.
Args:
- prefix_allowed_tokens_fn (`Callable[[int, ms.Tensor], List[int]]`):
+ prefix_allowed_tokens_fn (`Callable[[int, ms.Tensor], list[int]]`):
This function constraints the beam search to allowed tokens only at each step. This function takes 2
arguments `inputs_ids` and the batch ID `batch_id`. It has to return a list with the allowed tokens for the
next generation step conditioned on the previously generated tokens `inputs_ids` and the batch ID
`batch_id`.
+
+ Examples:
+
+ ```py
+ >>> import mindspore as ms
+ >>> from transformers import AutoTokenizer
+ >>> from mindone.transformers import AutoModelForCausalLM
+
+ >>> model = AutoModelForCausalLM.from_pretrained("bigscience/bloomz-560m")
+ >>> tokenizer = AutoTokenizer.from_pretrained("bigscience/bloomz-560m")
+
+ >>> inputs = tokenizer("Alice and Bob", return_tensors="np")
+ >>> for key in inputs.keys():
+ ...     inputs[key] = ms.tensor(inputs[key])
+
+ >>> # By default, it continues generating according to the model's logits
+ >>> outputs = model.generate(**inputs, max_new_tokens=5)
+ >>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
+ Alice and Bob are friends
+
+ >>> # We can constrain it with `prefix_allowed_tokens_fn` to force a certain behavior based on a prefix.
+ >>> # For instance, we can force an entire entity to be generated when its beginning is detected.
+ >>> entity = ms.tensor(tokenizer(" Bob Marley", return_tensors="np").input_ids[0]) # 3 tokens
+ >>> def prefix_allowed_tokens_fn(batch_id, input_ids):
+ ... '''
+ ... Attempts to generate 'Bob Marley' when 'Bob' is detected.
+ ... In this case, `batch_id` is not used, but you can set rules for each batch member.
+ ... '''
+ ... if input_ids[-1] == entity[0]:
+ ... return [entity[1].item()]
+ ... elif input_ids[-2] == entity[0] and input_ids[-1] == entity[1]:
+ ... return [entity[2].item()]
+ ... return list(range(tokenizer.vocab_size)) # If no match, allow all tokens
+
+ >>> outputs = model.generate(**inputs, max_new_tokens=5, prefix_allowed_tokens_fn=prefix_allowed_tokens_fn)
+ >>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
+ Alice and Bob Marley
+ ```
"""
- def __init__(self, prefix_allowed_tokens_fn: Callable[[int, ms.Tensor], List[int]], num_beams: int):
+ def __init__(self, prefix_allowed_tokens_fn: Callable[[int, ms.Tensor], list[int]], num_beams: int):
self._prefix_allowed_tokens_fn = prefix_allowed_tokens_fn
self._num_beams = num_beams
@@ -1326,7 +1566,8 @@ class HammingDiversityLogitsProcessor(LogitsProcessor):
Examples:
```python
- >>> from mindone.transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+ >>> from transformers import AutoTokenizer
+ >>> from mindone.transformers import AutoModelForSeq2SeqLM
>>> # Initialize the model and tokenizer
>>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base")
@@ -1491,7 +1732,7 @@ class ForcedEOSTokenLogitsProcessor(LogitsProcessor):
Args:
max_length (`int`):
The maximum length of the sequence to be generated.
- eos_token_id (`Union[int, List[int], ms.Tensor]`):
+ eos_token_id (`Union[int, list[int], ms.Tensor]`):
The id(s) of the *end-of-sequence* token.
Examples:
@@ -1516,7 +1757,7 @@ class ForcedEOSTokenLogitsProcessor(LogitsProcessor):
```
"""
- def __init__(self, max_length: int, eos_token_id: Union[int, List[int], ms.Tensor]):
+ def __init__(self, max_length: int, eos_token_id: Union[int, list[int], ms.Tensor]):
self.max_length = max_length
if not isinstance(eos_token_id, ms.Tensor):
@@ -1571,7 +1812,7 @@ class ExponentialDecayLengthPenalty(LogitsProcessor):
exponential_decay_length_penalty (`tuple(int, float)`):
This tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates where penalty
starts and `decay_factor` represents the factor of exponential decay
- eos_token_id (`Union[int, List[int], ms.Tensor]`):
+ eos_token_id (`Union[int, list[int], ms.Tensor]`):
The id(s) of the *end-of-sequence* token.
input_ids_seq_length (`int`):
The length of the input sequence.
@@ -1632,7 +1873,7 @@ class ExponentialDecayLengthPenalty(LogitsProcessor):
def __init__(
self,
exponential_decay_length_penalty: tuple[int, float],
- eos_token_id: Union[int, List[int], ms.Tensor],
+ eos_token_id: Union[int, list[int], ms.Tensor],
input_ids_seq_length: int,
):
self.regulation_start = exponential_decay_length_penalty[0] + input_ids_seq_length
@@ -1669,6 +1910,33 @@ class LogitNormalization(LogitsProcessor):
this library doesn't do it (it only does it before, but they may need re-normalization) but it still supposes that
the scores are normalized when comparing the hypotheses.
+ Examples:
+
+ ```python
+ >>> from transformers import AutoTokenizer
+ >>> from mindone.transformers import AutoModelForCausalLM
+ >>> import mindspore as ms
+ >>> from mindspore import mint
+
+ >>> model = AutoModelForCausalLM.from_pretrained("distilbert/distilgpt2")
+ >>> tokenizer = AutoTokenizer.from_pretrained("distilbert/distilgpt2")
+
+ >>> inputs = tokenizer("A sequence: 1, 2, 3", return_tensors="np")
+ >>> for key in inputs.keys():
+ ...     inputs[key] = ms.tensor(inputs[key])
+
+ >>> # By default, the scores are not normalized -- the sum of their exponentials is NOT a normalized probability
+ >>> # distribution, summing to 1
+ >>> outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=True)
+ >>> print(mint.allclose(mint.sum(mint.exp(outputs.scores[-1])), ms.Tensor((1.000,)), rtol=1e-4))
+ False
+
+ >>> # Normalizing them may have a positive impact on beam methods, or when using the scores on your application
+ >>> outputs = model.generate(**inputs, renormalize_logits=True, return_dict_in_generate=True, output_scores=True)
+ >>> print(mint.allclose(mint.sum(mint.exp(outputs.scores[-1])), ms.Tensor((1.000,)), rtol=1e-4))
+ True
+ ```
+
"""
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
@@ -1691,7 +1959,7 @@ class SuppressTokensAtBeginLogitsProcessor(LogitsProcessor):
[`SuppressTokensAtBeginLogitsProcessor`] supresses a list of tokens as soon as the `generate` function starts
generating using `begin_index` tokens. This should ensure that the tokens defined by `begin_suppress_tokens` are
not generated at the beginning. Originally created for
- [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper).
+ [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper).
Examples:
@@ -1745,7 +2013,7 @@ class SuppressTokensLogitsProcessor(LogitsProcessor):
r"""
This processor can be used to suppress a list of tokens. The processor will set their log probs to `-inf` so
that they are not generated. Originally created for
- [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper).
+ [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper).
Examples:
@@ -1797,7 +2065,7 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
potential tokens.
- See [the paper](https://arxiv.org/abs/2212.04356) for more information.
+ See [the paper](https://huggingface.co/papers/2212.04356) for more information.
Args:
generate_config (`GenerateConfig`):
@@ -1814,7 +2082,8 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
Examples:
``` python
- >>> from mindone.transformers import AutoProcessor, WhisperForConditionalGeneration, GenerationConfig
+ >>> from mindone.transformers import AutoProcessor, WhisperForConditionalGeneration
+ >>> from transformers import GenerationConfig
>>> from datasets import load_dataset
>>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
@@ -1858,11 +2127,13 @@ def __init__(
if _detect_timestamp_from_logprob is not None
else getattr(generate_config, "_detect_timestamp_from_logprob", True)
)
-
- num_forced_ids = (
- len(generate_config.forced_decoder_ids) if generate_config.forced_decoder_ids is not None else 0
- )
- self.begin_index = begin_index or (num_forced_ids + 1)
+ self.begin_index = begin_index
+ if begin_index is None:
+ raise ValueError(
+ "`forced_decoder_ids` is deprecated in favor of `task` and `language` and, as such, `begin_index` "
+ "must be provided to `WhisperTimeStampLogitsProcessor`. The previous default value of `begin_index` "
+ "was `len(generate_config.forced_decoder_ids)`"
+ )
self.max_initial_timestamp_index = getattr(generate_config, "max_initial_timestamp_index", None)
# TODO(Patrick): Make sure that official models have max_initial_timestamp_index set to 50
@@ -1947,6 +2218,10 @@ def set_inputs(self, inputs):
self.inputs = {**self.model.prepare_inputs_for_generation(**inputs), **inputs}
self.inputs["input_features"] = self.inputs.pop("inputs")
+ # Whisper encoder-decoder does not accept the input_ids as input
+ if "input_ids" not in inspect.signature(self.model.forward).parameters:
+ self.inputs.pop("input_ids", None)
+
@property
def no_speech_prob(self):
return self._no_speech_prob
@@ -1985,12 +2260,12 @@ class ClassifierFreeGuidanceLogitsProcessor(LogitsProcessor):
correspond to the unconditional logits (predicted from an empty or 'null' prompt). The processor computes a
weighted average across the conditional and unconditional logits, parameterised by the `guidance_scale`.
- See [the paper](https://arxiv.org/abs/2306.05284) for more information.
+ See [the paper](https://huggingface.co/papers/2306.05284) for more information.
This logits processor is exclusively compatible with
- [MusicGen](https://huggingface.co/docs/transformers/main/en/model_doc/musicgen)
+ [MusicGen](https://huggingface.co/docs/transformers/main/en/model_doc/musicgen)
@@ -2053,7 +2328,7 @@ class AlternatingCodebooksLogitsProcessor(LogitsProcessor):
This logits processor is exclusively compatible with
- [Bark](https://huggingface.co/docs/transformers/en/model_doc/bark)'s fine submodel. See the model documentation
+ [Bark](https://huggingface.co/docs/transformers/en/model_doc/bark)'s fine submodel. See the model documentation
for examples.
@@ -2097,7 +2372,7 @@ class UnbatchedClassifierFreeGuidanceLogitsProcessor(LogitsProcessor):
from prompt conditional and prompt unconditional (or negative) logits, parameterized by the `guidance_scale`.
The unconditional scores are computed internally by prompting `model` with the `unconditional_ids` branch.
- See [the paper](https://arxiv.org/abs/2306.17806) for more information.
+ See [the paper](https://huggingface.co/papers/2306.17806) for more information.
Args:
guidance_scale (`float`):
@@ -2120,7 +2395,8 @@ class UnbatchedClassifierFreeGuidanceLogitsProcessor(LogitsProcessor):
Examples:
```python
- >>> from mindone.transformers import AutoTokenizer, AutoModelForCausalLM
+ >>> from transformers import AutoTokenizer
+ >>> from mindone.transformers import AutoModelForCausalLM
>>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
>>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
@@ -2215,18 +2491,18 @@ class BarkEosPrioritizerLogitsProcessor(LogitsProcessor):
This logits processor is exclusively compatible with
- [Bark](https://huggingface.co/docs/transformers/en/model_doc/bark). See the model documentation for examples.
+ [Bark](https://huggingface.co/docs/transformers/en/model_doc/bark). See the model documentation for examples.
Args:
- eos_token_id (`Union[int, List[int], ms.Tensor]`):
+ eos_token_id (`Union[int, list[int], ms.Tensor]`):
The id(s) of the *end-of-sequence* token.
min_eos_p (`float`, *optional*):
Minimum end of speech threshold.
"""
- def __init__(self, eos_token_id: Union[int, List[int], ms.Tensor], min_eos_p: float):
+ def __init__(self, eos_token_id: Union[int, list[int], ms.Tensor], min_eos_p: float):
if not isinstance(eos_token_id, ms.Tensor):
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
@@ -2266,7 +2542,7 @@ class WatermarkLogitsProcessor(LogitsProcessor):
The text generated by this `LogitsProcessor` can be detected using `WatermarkDetector`. See [`~WatermarkDetector.__call__`] for details,
- See [the paper](https://arxiv.org/abs/2306.04634) for more information.
+ See [the paper](https://huggingface.co/papers/2306.04634) for more information.
Args:
vocab_size (`int`):
@@ -2461,7 +2737,7 @@ class SynthIDTextWatermarkLogitsProcessor(LogitsProcessor):
Args:
ngram_len (`int`):
Ngram length.
- keys (`List[int]`):
+ keys (`list[int]`):
A sequence of watermarking keys, one for each depth.
sampling_table_size (`int`):
Size of the sampling table.
@@ -2500,7 +2776,7 @@ class SynthIDTextWatermarkLogitsProcessor(LogitsProcessor):
def __init__(
self,
ngram_len: int,
- keys: List[int],
+ keys: list[int],
sampling_table_size: int,
sampling_table_seed: int,
context_history_size: int,
@@ -2672,7 +2948,7 @@ def compute_ngram_keys(self, ngrams: ms.Tensor) -> ms.Tensor:
ngram keys (batch_size, num_ngrams, depth).
"""
if len(ngrams.shape) != 3:
- raise ValueError("Ngrams should be of shape (batch_size, num_ngrams, ngram_len), but" f" is {ngrams.shape}")
+ raise ValueError(f"Ngrams should be of shape (batch_size, num_ngrams, ngram_len), but is {ngrams.shape}")
if ngrams.shape[2] != self.ngram_len:
raise ValueError(
"Ngrams should be of shape (batch_size, num_ngrams, ngram_len),"
@@ -2860,3 +3136,223 @@ def expected_mean_g_value(self, vocab_size: int, coinflip_prob: float = 0.5) ->
The expected mean g-value for watermarked text.
"""
return coinflip_prob + coinflip_prob * (1 - coinflip_prob) * (1 - (1 / vocab_size))
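+
+ # Quick numeric check of the formula above (illustrative): with coinflip_prob = 0.5 and
+ # vocab_size = 2, the expected mean g-value is 0.5 + 0.5 * 0.5 * (1 - 1 / 2) = 0.625.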
+
+
+class DiaClassifierFreeGuidanceLogitsProcessor(LogitsProcessor):
+ r"""
+ [`LogitsProcessor`] for classifier free guidance (CFG). Similar to the original
+ `ClassifierFreeGuidanceLogitsProcessor` with some modifications on the overall
+ calculation, e.g. conditioned logits centered, and an additional top k selection
+ option.
+
+
+
+ This logits processor is exclusively compatible with
+ [Dia](https://huggingface.co/docs/transformers/main/en/model_doc/dia)
+
+
+
+ Args:
+ guidance_scale (float):
+ The guidance scale for classifier free guidance (CFG). CFG is enabled by setting `guidance_scale > 1`.
+ Higher guidance scale encourages the model to generate samples that are more closely linked to the input
+ prompt, usually at the expense of poorer quality.
+ guidance_top_k (int, *optional*):
+ The number of highest probability vocabulary tokens to keep for top-k-filtering. However, we do not keep
+ the logits of the combined CFG output, but the conditioned output only.
+ """
+
+ def __init__(self, guidance_scale: float, guidance_top_k: Optional[int] = None):
+ if guidance_scale > 1:
+ self.guidance_scale = guidance_scale
+ else:
+ raise ValueError(
+ "Require guidance scale >1 to use the classifier free guidance processor, got guidance scale "
+ f"{guidance_scale}."
+ )
+
+ self.guidance_top_k = guidance_top_k
+ if self.guidance_top_k is not None and self.guidance_top_k < 1:
+ raise ValueError(
+ f"`guidance_top_k` has to be a strictly positive integer if given, but is {self.guidance_top_k}"
+ )
+
+ @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
+ def __call__(self, input_ids: ms.Tensor, scores: ms.Tensor) -> ms.Tensor:
+ # simple check to make sure we have compatible batch sizes between our
+ # logits scores (cond + uncond) and input ids (cond only)
+ if scores.shape[0] != 2 * input_ids.shape[0]:
+ raise ValueError(
+ f"Logits should have twice the batch size of the input ids, the first half of batches corresponding to "
+ f"the conditional inputs, and the second half of batches corresponding to the unconditional inputs. Got "
+ f"batch size {scores.shape[0]} for the logits and {input_ids.shape[0]} for the input ids."
+ )
+ # Base CFG with center on cond_logits
+ unguided_bsz = scores.shape[0] // 2
+ cond_logits, uncond_logits = scores.split(unguided_bsz, dim=0)
+ scores_processed = cond_logits + (cond_logits - uncond_logits) * self.guidance_scale
+
+ # Optional CFG top k filtering
+ if self.guidance_top_k is not None:
+ # Create top k based on the combined CFG output
+ _, top_k_indices = mint.topk(scores_processed, k=self.guidance_top_k, dim=-1)
+ top_k_mask = mint.ones_like(scores_processed, dtype=ms.bool_)
+ top_k_mask = top_k_mask.scatter(dim=-1, index=top_k_indices, value=False)
+ # Only return conditioned logits with top k
+ scores_processed = cond_logits.masked_fill(top_k_mask, -float("inf"))
+
+ return scores_processed
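+
+ # Numeric sketch of the CFG combination above (illustrative values only):
+ # cond = [2.0, 0.0], uncond = [1.0, 1.0], guidance_scale = 2.0
+ # guided = cond + (cond - uncond) * 2.0 = [4.0, -2.0]
+ # With guidance_top_k = 1 the top-k index (0) is selected from `guided`, but the value returned
+ # at that index is the conditioned logit 2.0; all other positions are filled with -inf.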
+
+
+class DiaEOSChannelFilterLogitsProcessor(LogitsProcessor):
+ r"""Specialized processor that ensures certain properties around EOS sampling:
+ 1. Only channel 0 can generate EOS
+ 2. If channel 0 has EOS with highest logit, it will be the only candidate
+ 3. If channel 0 has EOS not with highest logit, it will be suppressed
+
+ 2. and 3. are especially important in contexts where we allow sampling to guarantee the
+ respective tokens to be (not) sampled.
+
+
+
+ This logits processor is exclusively compatible with
+ [Dia](https://huggingface.co/docs/transformers/en/model_doc/dia).
+
+
+
+ Args:
+ num_channels (`int`):
+ Number of audio codebooks. Simplifies access to the first channel on the logits.
+ eos_token_id (`int`):
+ The id of *end-of-sequence* token.
+ """
+
+ def __init__(self, num_channels: int, eos_token_id: int):
+ if num_channels < 1:
+ raise ValueError(f"Audio codebooks need at least one channel, but found {num_channels} channels.")
+ if eos_token_id < 1:
+ raise ValueError(f"Expected `eos_token_id` to be a positive integer, found {eos_token_id} instead.")
+
+ self.num_channels = num_channels
+ self.eos_id = eos_token_id
+
+ @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
+ def __call__(self, input_ids: ms.Tensor, scores: ms.Tensor) -> ms.Tensor:
+ # Reshape for easier channel indexing [B, C, V]
+ scores = scores.reshape(-1, self.num_channels, scores.shape[-1])
+
+ # EOS filter
+ # 1. Condition: Only the first channel can generate the EOS token
+ # Side condition of disabling generation of special tokens (e.g. audio pad, bos, ...)
+ # (Assumes them to be greater than audio eos token position)
+ scores[:, 1:, self.eos_id :] = mint.full_like(
+ scores[:, 1:, self.eos_id :],
+ fill_value=-float("inf"),
+ )
+ scores[:, 0, self.eos_id + 1 :] = mint.full_like(
+ scores[:, 0, self.eos_id + 1 :],
+ fill_value=-float("inf"),
+ )
+
+ # 2+3 Conditions: Force/Suppress EOS if (not) highest logit
+ # Reshape back to original shape
+ scores = scores.view(-1, scores.shape[-1])
+
+ # Sample highest tokens
+ top_logit_indices = mint.argmax(scores, dim=-1)
+
+ # 2. Force EOS
+ eos_highest_mask = top_logit_indices == self.eos_id
+ mask_eos_highest = mint.zeros_like(scores, dtype=ms.bool_)
+ mask_eos_highest[eos_highest_mask, : self.eos_id] = True
+ scores = scores.masked_fill(mask_eos_highest, -float("inf"))
+
+ # 3. Suppress EOS
+ eos_not_highest_mask = top_logit_indices != self.eos_id
+ mask_eos_unless_highest = mint.zeros_like(scores, dtype=ms.bool_)
+ mask_eos_unless_highest[eos_not_highest_mask, self.eos_id] = True
+ scores = scores.masked_fill(mask_eos_unless_highest, -float("inf"))
+
+ return scores
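+
+ # Worked toy example (illustrative; 2 channels, vocab of 4 audio tokens, eos_token_id = 2):
+ # channel 0: [0.1, 0.2, 5.0, 9.0] -> special tokens above EOS masked -> [0.1, 0.2, 5.0, -inf];
+ # EOS is now the argmax, so all non-EOS logits are masked and EOS is guaranteed (condition 2).
+ # channel 1: [1.0, 4.0, 3.0, 2.0] -> EOS and above masked -> [1.0, 4.0, -inf, -inf];
+ # the argmax is not EOS, so EOS stays suppressed and regular sampling proceeds (conditions 1 and 3).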
+
+
+class DiaEOSDelayPatternLogitsProcessor(LogitsProcessor):
+ r"""Special logits processor to handle the generation of the EOS token in Dia.
+ This is due to the fact that Dia does not allow the generation of EOS in all
+ channels except the first channel (C0).
+
+ Hence, based on the delay pattern, an EOS is forced after the respective delays
+ in the channels. For example, if the delay pattern is [0, 2, 3, 4]:
+
+ s s+1 s+2 s+3 s+4 s+5 ...
+ | | | | | |
+ C0: EOS PAD PAD PAD PAD PAD ...
+ C1: x x EOS PAD PAD PAD ...
+ C2: x x x EOS PAD PAD ...
+ C3: x x x x EOS PAD ...
+
+ If the first channel generated EOS at step s, channels Cx are forced to generate
+ theirs at the respective delays (s+2, s+3, s+4). Subsequent padding tokens are
+ handled by the `EosTokenCriteria` when an EOS has been detected.
+
+
+
+ This logits processor is exclusively compatible with
+ [Dia](https://huggingface.co/docs/transformers/en/model_doc/dia).
+
+
+
+ Args:
+ delay_pattern (`list[int]`):
+ The delays per channel in the audio codebooks.
+ eos_token_id (`int`):
+ The id of *end-of-sequence* token.
+ max_generation_len (`int`):
+ The max sequence length that can be generated.
+ """
+
+ def __init__(self, delay_pattern: list[int], eos_token_id: int, max_generation_len: int):
+ self.num_channels = len(delay_pattern)
+ # Update during first iteration
+ self.active_batches = None
+ self.delay_pattern = ms.tensor(delay_pattern, dtype=ms.int32)[None, :]
+ self.eos_token_id = eos_token_id
+ self.max_generation_len = max_generation_len - max(delay_pattern) - 1
+
+ @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
+ def __call__(self, input_ids: ms.Tensor, scores: ms.Tensor) -> ms.Tensor:
+ # Reshape for easier channel indexing [B, C, V]
+ scores = scores.reshape(-1, self.num_channels, scores.shape[-1])
+
+ # Initialize / expand values on first iteration
+ if self.active_batches is None:
+ self.delay_pattern = mint.tile(self.delay_pattern, (scores.shape[0], 1))  # expand (1, C) -> (batch, C)
+ self.active_batches = mint.zeros(size=(scores.shape[0],), dtype=ms.bool_)
+
+ # Check if eos has been generated in any batch
+ channel_generated_eos = mint.argmax(scores, dim=-1)[:, 0] == self.eos_token_id
+ # Check if max len has been reached
+ reached_max_len = input_ids.shape[1] == self.max_generation_len
+
+ # Update active batches
+ self.active_batches |= channel_generated_eos
+ self.active_batches |= reached_max_len
+
+ # Find channels that need to force eos
+ forced_eos_channels = self.active_batches[:, None] & (self.delay_pattern == 0)
+ # Use indexing to avoid issues on all `False` by having empty tensors in that case
+ idx_bsz, idx_channel = forced_eos_channels.nonzero(as_tuple=True)
+
+ # Force eos if delay is kicking in
+ scores[idx_bsz, idx_channel, :] = -float("inf")
+ scores[idx_bsz, idx_channel, self.eos_token_id] = 0.0
+
+ # Reshape back to [B * C, V]
+ scores = scores.reshape(-1, scores.shape[-1])
+
+ # Update amount of delay left for each channel
+ self.delay_pattern -= self.active_batches[:, None].int()
+
+ return scores
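+
+ # Illustrative walk-through (hypothetical step numbers): with delay_pattern = [0, 2, 3, 4], once
+ # channel 0 produces EOS at step s the batch is marked active and the per-channel counters start
+ # decreasing by one per step; whenever a counter reaches 0, that channel's logits are overwritten so
+ # that only EOS can be sampled (step s+2 for C1, s+3 for C2, s+4 for C3), matching the table above.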
diff --git a/mindone/transformers/generation/stopping_criteria.py b/mindone/transformers/generation/stopping_criteria.py
index 55e50f0d97..9b4937d31a 100644
--- a/mindone/transformers/generation/stopping_criteria.py
+++ b/mindone/transformers/generation/stopping_criteria.py
@@ -5,7 +5,7 @@
from abc import ABC
from collections import OrderedDict
from copy import deepcopy
-from typing import List, Optional, Union
+from typing import Optional, Union
import numpy as np
from transformers import PreTrainedTokenizerBase
@@ -26,14 +26,14 @@
STOPPING_CRITERIA_INPUTS_DOCSTRING = r"""
Args:
- input_ids (`Union[ms.Tensor, numpy.ndarray]` of shape `(batch_size, sequence_length)`):
+ input_ids (`ms.Tensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
- scores (`Union[ms.Tensor, numpy.ndarray]` of shape `(batch_size, config.vocab_size)`):
+ scores (`ms.Tensor` of shape `(batch_size, config.vocab_size)`):
Prediction scores of a language modeling head. These can be scores for each vocabulary token before SoftMax
or scores for each vocabulary token after SoftMax. If this stopping criteria depends on the `scores` input,
make sure you pass `return_dict_in_generate=True, output_scores=True` to `generate`.
@@ -41,8 +41,9 @@
Additional stopping criteria specific kwargs.
Return:
- `Union[ms.Tensor, numpy.ndarray]`. (`Union[ms.Tensor, numpy.ndarray]` of shape `(batch_size, 1)`), where `True` indicates we stop generation
- for a particular row, `True` indicates we should continue.
+ `ms.Tensor` of `bool` dtype and shape `(batch_size, 1)`:
+ `True` indicates we stop generation for a particular row.
+ `False` indicates we should continue.
"""
@@ -79,7 +80,7 @@ def __init__(self, max_length: int, max_position_embeddings: Optional[int] = Non
def __call__(
self, input_ids: Union[ms.Tensor, np.ndarray], scores: Union[ms.Tensor, np.ndarray], **kwargs
) -> Union[ms.Tensor, np.ndarray]:
- cur_len = input_ids.shape[-1]
+ cur_len = input_ids.shape[1]
is_done = cur_len >= self.max_length
if self.max_position_embeddings is not None and not is_done and cur_len >= self.max_position_embeddings:
logger.warning_once(
@@ -228,7 +229,7 @@ class StopStringCriteria(StoppingCriteria):
Args:
tokenizer (`PreTrainedTokenizer`):
The model's associated tokenizer (necessary to extract vocab and tokenize the termination sequences)
- stop_strings (`Union[str, List[str]]`):
+ stop_strings (`Union[str, list[str]]`):
A list of strings that should end generation. If a string is passed, it will be treated like a
list with a single element.
@@ -258,7 +259,7 @@ class StopStringCriteria(StoppingCriteria):
```
"""
- def __init__(self, tokenizer: PreTrainedTokenizerBase, stop_strings: Union[str, List[str]]):
+ def __init__(self, tokenizer: PreTrainedTokenizerBase, stop_strings: Union[str, list[str]]):
if isinstance(stop_strings, str):
stop_strings = [stop_strings]
self.stop_strings: tuple[str, ...] = tuple(stop_strings)
@@ -317,7 +318,7 @@ def clean_tokenizer_vocab(tokenizer, static_prefix="abcdef"):
@staticmethod
def _stop_string_get_matching_positions(
token_list, token_indices, stop_strings
- ) -> tuple[dict[str, dict[str, List[int]]], dict[str, dict[str, List[int]]]]:
+ ) -> tuple[dict[str, dict[str, list[int]]], dict[str, dict[str, list[int]]]]:
"""This function preprocesses stop strings and the tokenizer vocabulary to determine where tokens can
validly appear in the stop strings. For each token, it computes a list of positions in the stop string where the
token appears, as well as a list of the possible "end overlaps" for that token - that is, the number of characters
@@ -472,11 +473,11 @@ class EosTokenCriteria(StoppingCriteria):
By default, it uses the `model.generation_config.eos_token_id`.
Args:
- eos_token_id (`Union[int, List[int], ms.Tensor]`):
+ eos_token_id (`Union[int, list[int], ms.Tensor]`):
The id(s) of the *end-of-sequence* token.
"""
- def __init__(self, eos_token_id: Union[int, List[int], ms.Tensor]):
+ def __init__(self, eos_token_id: Union[int, list[int], ms.Tensor]):
# to list
if not isinstance(eos_token_id, ms.Tensor):
if isinstance(eos_token_id, int):
diff --git a/mindone/transformers/generation/utils.py b/mindone/transformers/generation/utils.py
index 091afb8713..7338e9330e 100644
--- a/mindone/transformers/generation/utils.py
+++ b/mindone/transformers/generation/utils.py
@@ -18,21 +18,31 @@
# limitations under the License.
import copy
import inspect
+import os
import time
import warnings
from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Optional, Union
import numpy as np
+from huggingface_hub import file_exists
from packaging import version
-from transformers import logging
-from transformers.generation.configuration_utils import GenerationConfig, GenerationMode
+from transformers import PretrainedConfig, logging
+from transformers.dynamic_module_utils import (
+ check_python_requirements,
+ get_cached_module_file,
+ get_class_in_module,
+ resolve_trust_remote_code,
+)
+from transformers.generation.beam_constraints import DisjunctiveConstraint, PhrasalConstraint
+from transformers.generation.configuration_utils import CompileConfig, GenerationConfig, GenerationMode
from transformers.tokenization_utils import ExtensionsTrie
from transformers.utils.generic import ModelOutput
import mindspore as ms
import mindspore.numpy as mnp
from mindspore import mint, ops
+from mindspore.mint.nn import functional as F
from mindone.transformers.cache_utils import (
Cache,
@@ -47,6 +57,7 @@
init_static_cache,
reset,
)
+from mindone.transformers.generation.beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
from mindone.transformers.generation.candidate_generator import (
AssistantVocabTranslatorCache,
AssistedCandidateGenerator,
@@ -55,6 +66,8 @@
EarlyExitCandidateGenerator,
PromptLookupCandidateGenerator,
UniversalSpeculativeDecodingGenerator,
+ _prepare_attention_mask,
+ _prepare_token_type_ids,
)
from mindone.transformers.generation.logits_process import (
EncoderNoRepeatNGramLogitsProcessor,
@@ -93,8 +106,11 @@
StoppingCriteriaList,
StopStringCriteria,
)
+from mindone.transformers.masking_utils import create_masks_for_generate
+from mindone.transformers.mindspore_adapter import dtype_to_min
from mindone.transformers.mindspore_adapter.paged_attention_block_tables import BlockTables
from mindone.transformers.mindspore_adapter.select_operator import get_multinomial_op
+from mindone.transformers.modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput
if TYPE_CHECKING:
from transformers.generation.streamers import BaseStreamer
@@ -135,33 +151,33 @@ class GenerateDecoderOnlyOutput(ModelOutput):
if all batches finished early due to the `eos_token_id`.
scores (`tuple(ms.Tensor)` *optional*, returned when `output_scores=True`):
Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
- at each generation step. Tuple of `ms.Tensor` with up to `max_new_tokens` elements (one element for
+ at each generation step. tuple of `ms.Tensor` with up to `max_new_tokens` elements (one element for
each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
logits (`tuple(ms.Tensor)` *optional*, returned when `output_logits=True`):
Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
- at each generation step. Tuple of `ms.Tensor` with up to `max_new_tokens` elements (one element for
+ at each generation step. tuple of `ms.Tensor` with up to `max_new_tokens` elements (one element for
each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
attentions (`tuple(tuple(ms.Tensor))`, *optional*, returned when `output_attentions=True`):
- Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
+ tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
`ms.Tensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
hidden_states (`tuple(tuple(ms.Tensor))`, *optional*, returned when `output_hidden_states=True`):
- Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
+ tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
`ms.Tensor` of shape `(batch_size, generated_length, hidden_size)`.
past_key_values (`tuple(tuple(mindspore.Tensor)))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
NOTE: some models have a different `past_key_values` format, confirm with the model's documentation.
- Usually a Tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value
- tensor). The first Tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape
+ Usually a tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value
+ tensor). The first tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
`config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
encoder_sequence_length, embed_size_per_head)`.
"""
- sequences: ms.Tensor = None
- scores: Optional[Tuple[ms.Tensor]] = None
- logits: Optional[Tuple[ms.Tensor]] = None
- attentions: Optional[Tuple[Tuple[ms.Tensor]]] = None
- hidden_states: Optional[Tuple[Tuple[ms.Tensor]]] = None
- past_key_values: Optional[Tuple[Tuple[Tuple[ms.Tensor]]]] = None
+ sequences: ms.Tensor
+ scores: Optional[tuple[ms.Tensor]] = None
+ logits: Optional[tuple[ms.Tensor]] = None
+ attentions: Optional[tuple[tuple[ms.Tensor]]] = None
+ hidden_states: Optional[tuple[tuple[ms.Tensor]]] = None
+ past_key_values: Optional[tuple[tuple[tuple[ms.Tensor]]]] = None
@dataclass
@@ -175,45 +191,45 @@ class GenerateEncoderDecoderOutput(ModelOutput):
if all batches finished early due to the `eos_token_id`.
scores (`tuple(ms.Tensor)` *optional*, returned when `output_scores=True`):
Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
- at each generation step. Tuple of `ms.Tensor` with up to `max_new_tokens` elements (one element for
+ at each generation step. tuple of `ms.Tensor` with up to `max_new_tokens` elements (one element for
each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
logits (`tuple(ms.Tensor)` *optional*, returned when `output_logits=True`):
Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
- at each generation step. Tuple of `ms.Tensor` with up to `max_new_tokens` elements (one element for
+ at each generation step. tuple of `ms.Tensor` with up to `max_new_tokens` elements (one element for
each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
encoder_attentions (`tuple(ms.Tensor)`, *optional*, returned when `output_attentions=True`):
- Tuple of `ms.Tensor` (one for each layer of the decoder) of shape `(batch_size, num_heads,
+ tuple of `ms.Tensor` (one for each layer of the decoder) of shape `(batch_size, num_heads,
sequence_length, sequence_length)`.
encoder_hidden_states (`tuple(ms.Tensor)`, *optional*, returned when `output_hidden_states=True`):
- Tuple of `ms.Tensor` (one for the output of the embeddings + one for the output of each layer) of
+ tuple of `ms.Tensor` (one for the output of the embeddings + one for the output of each layer) of
shape `(batch_size, sequence_length, hidden_size)`.
decoder_attentions (`tuple(tuple(ms.Tensor))`, *optional*, returned when `output_attentions=True`):
- Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
+ tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
`ms.Tensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
cross_attentions (`tuple(tuple(ms.Tensor))`, *optional*, returned when `output_attentions=True`):
- Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
+ tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
`ms.Tensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
decoder_hidden_states (`tuple(tuple(ms.Tensor))`, *optional*, returned when `output_hidden_states=True`):
- Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
+ tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
`ms.Tensor` of shape `(batch_size, generated_length, hidden_size)`.
past_key_values (`tuple(tuple(mindspore.Tensor)))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
NOTE: some models have a different `past_key_values` format, confirm with the model's documentation.
- Usually a Tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value
- tensor). The first Tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape
+ Usually a tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value
+ tensor). The first tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
`config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
encoder_sequence_length, embed_size_per_head)`.
"""
- sequences: ms.Tensor = None
- scores: Optional[Tuple[ms.Tensor]] = None
- logits: Optional[Tuple[ms.Tensor]] = None
- encoder_attentions: Optional[Tuple[ms.Tensor]] = None
- encoder_hidden_states: Optional[Tuple[ms.Tensor]] = None
- decoder_attentions: Optional[Tuple[Tuple[ms.Tensor]]] = None
- cross_attentions: Optional[Tuple[Tuple[ms.Tensor]]] = None
- decoder_hidden_states: Optional[Tuple[Tuple[ms.Tensor]]] = None
- past_key_values: Optional[Tuple[Tuple[Tuple[ms.Tensor]]]] = None
+ sequences: ms.Tensor
+ scores: Optional[tuple[ms.Tensor]] = None
+ logits: Optional[tuple[ms.Tensor]] = None
+ encoder_attentions: Optional[tuple[ms.Tensor]] = None
+ encoder_hidden_states: Optional[tuple[ms.Tensor]] = None
+ decoder_attentions: Optional[tuple[tuple[ms.Tensor]]] = None
+ cross_attentions: Optional[tuple[tuple[ms.Tensor]]] = None
+ decoder_hidden_states: Optional[tuple[tuple[ms.Tensor]]] = None
+ past_key_values: Optional[tuple[tuple[tuple[ms.Tensor]]]] = None
@dataclass
@@ -230,20 +246,20 @@ class GenerateBeamDecoderOnlyOutput(ModelOutput):
scores (`tuple(ms.Tensor)` *optional*, returned when `output_scores=True`):
Beam transition scores for each vocabulary token at each generation step. Beam transition scores consisting
of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam.
- Tuple of `ms.Tensor` with up to `max_new_tokens` elements (one element for each generated token),
+ tuple of `ms.Tensor` with up to `max_new_tokens` elements (one element for each generated token),
with each tensor of shape `(batch_size*num_beams, config.vocab_size)`.
logits (`tuple(ms.Tensor)` *optional*, returned when `output_logits=True`):
Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
- at each generation step. Tuple of `ms.Tensor` with up to `max_new_tokens` elements (one element for
+ at each generation step. tuple of `ms.Tensor` with up to `max_new_tokens` elements (one element for
each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
beam_indices (`ms.Tensor`, *optional*, returned when `output_scores=True`):
Beam indices of generated token id at each generation step. `ms.Tensor` of shape
`(batch_size*num_return_sequences, sequence_length)`.
attentions (`tuple(tuple(ms.Tensor))`, *optional*, returned when `output_attentions=True`):
- Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
+ tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
`ms.Tensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`.
hidden_states (`tuple(tuple(ms.Tensor))`, *optional*, returned when `output_hidden_states=True`):
- Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
+ tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
`ms.Tensor` of shape `(batch_size*num_beams*num_return_sequences, generated_length, hidden_size)`.
past_key_values (`tuple(tuple(ms.Tensor)))`, *optional*, returned when `use_cache=True`):
Returns the model cache, used to speed up decoding. Different models have a different cache format, check
@@ -252,12 +268,12 @@ class GenerateBeamDecoderOnlyOutput(ModelOutput):
sequences: ms.Tensor = None
sequences_scores: Optional[ms.Tensor] = None
- scores: Optional[Tuple[ms.Tensor]] = None
- logits: Optional[Tuple[ms.Tensor]] = None
+ scores: Optional[tuple[ms.Tensor]] = None
+ logits: Optional[tuple[ms.Tensor]] = None
beam_indices: Optional[ms.Tensor] = None
- attentions: Optional[Tuple[Tuple[ms.Tensor]]] = None
- hidden_states: Optional[Tuple[Tuple[ms.Tensor]]] = None
- past_key_values: Optional[Tuple[Tuple[Tuple[ms.Tensor]]]] = None
+ attentions: Optional[tuple[tuple[ms.Tensor]]] = None
+ hidden_states: Optional[tuple[tuple[ms.Tensor]]] = None
+ past_key_values: Optional[tuple[tuple[tuple[ms.Tensor]]]] = None
@dataclass
@@ -274,48 +290,70 @@ class GenerateBeamEncoderDecoderOutput(ModelOutput):
scores (`tuple(ms.Tensor)` *optional*, returned when `output_scores=True`):
Beam transition scores for each vocabulary token at each generation step. Beam transition scores consisting
of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam.
- Tuple of `ms.Tensor` with up to `max_new_tokens` elements (one element for each generated token),
+ tuple of `ms.Tensor` with up to `max_new_tokens` elements (one element for each generated token),
with each tensor of shape `(batch_size*num_beams, config.vocab_size)`.
logits (`tuple(ms.Tensor)` *optional*, returned when `output_logits=True`):
Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
- at each generation step. Tuple of `ms.Tensor` with up to `max_new_tokens` elements (one element for
+ at each generation step. tuple of `ms.Tensor` with up to `max_new_tokens` elements (one element for
each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
beam_indices (`ms.Tensor`, *optional*, returned when `output_scores=True`):
Beam indices of generated token id at each generation step. `ms.Tensor` of shape
`(batch_size*num_return_sequences, sequence_length)`.
encoder_attentions (`tuple(ms.Tensor)`, *optional*, returned when `output_attentions=True`):
- Tuple of `ms.Tensor` (one for each layer of the decoder) of shape `(batch_size, num_heads,
+ tuple of `ms.Tensor` (one for each layer of the decoder) of shape `(batch_size, num_heads,
sequence_length, sequence_length)`.
encoder_hidden_states (`tuple(ms.Tensor)`, *optional*, returned when `output_hidden_states=True`):
- Tuple of `ms.Tensor` (one for the output of the embeddings + one for the output of each layer) of
+ tuple of `ms.Tensor` (one for the output of the embeddings + one for the output of each layer) of
shape `(batch_size*num_beams*num_return_sequences, sequence_length, hidden_size)`.
decoder_attentions (`tuple(tuple(ms.Tensor))`, *optional*, returned when `output_attentions=True`):
- Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
+ tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
`ms.Tensor` of shape `(batch_size*num_beams*num_return_sequences, num_heads, generated_length,
sequence_length)`.
cross_attentions (`tuple(tuple(ms.Tensor))`, *optional*, returned when `output_attentions=True`):
- Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
+ tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
`ms.Tensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
decoder_hidden_states (`tuple(tuple(ms.Tensor))`, *optional*, returned when `output_hidden_states=True`):
- Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
+ tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
`ms.Tensor` of shape `(batch_size*num_beams*num_return_sequences, generated_length, hidden_size)`.
past_key_values (`tuple(tuple(ms.Tensor)))`, *optional*, returned when `use_cache=True`):
Returns the model cache, used to speed up decoding. Different models have a different cache format, check
the model's documentation. Usually, a [`~cache_utils.Cache`] instance.
"""
- sequences: ms.Tensor = None
+ sequences: ms.Tensor
sequences_scores: Optional[ms.Tensor] = None
- scores: Optional[Tuple[ms.Tensor]] = None
- logits: Optional[Tuple[ms.Tensor]] = None
+ scores: Optional[tuple[ms.Tensor]] = None
+ logits: Optional[tuple[ms.Tensor]] = None
beam_indices: Optional[ms.Tensor] = None
- encoder_attentions: Optional[Tuple[ms.Tensor]] = None
- encoder_hidden_states: Optional[Tuple[ms.Tensor]] = None
- decoder_attentions: Optional[Tuple[Tuple[ms.Tensor]]] = None
- cross_attentions: Optional[Tuple[Tuple[ms.Tensor]]] = None
- decoder_hidden_states: Optional[Tuple[Tuple[ms.Tensor]]] = None
- past_key_values: Optional[Tuple[Tuple[Tuple[ms.Tensor]]]] = None
+ encoder_attentions: Optional[tuple[ms.Tensor]] = None
+ encoder_hidden_states: Optional[tuple[ms.Tensor]] = None
+ decoder_attentions: Optional[tuple[tuple[ms.Tensor]]] = None
+ cross_attentions: Optional[tuple[tuple[ms.Tensor]]] = None
+ decoder_hidden_states: Optional[tuple[tuple[ms.Tensor]]] = None
+ past_key_values: Optional[tuple[tuple[tuple[ms.Tensor]]]] = None
+
+
+# TODO (joao): remove the equivalent classes and typing shortcuts below in v5
+# Equivalent classes (kept for retrocompatibility purposes)
+GreedySearchDecoderOnlyOutput = GenerateDecoderOnlyOutput
+ContrastiveSearchDecoderOnlyOutput = GenerateDecoderOnlyOutput
+SampleDecoderOnlyOutput = GenerateDecoderOnlyOutput
+
+ContrastiveSearchEncoderDecoderOutput = GenerateEncoderDecoderOutput
+GreedySearchEncoderDecoderOutput = GenerateEncoderDecoderOutput
+SampleEncoderDecoderOutput = GenerateEncoderDecoderOutput
+
+BeamSearchDecoderOnlyOutput = GenerateBeamDecoderOnlyOutput
+BeamSampleDecoderOnlyOutput = GenerateBeamDecoderOnlyOutput
+BeamSearchEncoderDecoderOutput = GenerateBeamEncoderDecoderOutput
+BeamSampleEncoderDecoderOutput = GenerateBeamEncoderDecoderOutput
+
+GreedySearchOutput = Union[GreedySearchEncoderDecoderOutput, GreedySearchDecoderOnlyOutput]
+SampleOutput = Union[SampleEncoderDecoderOutput, SampleDecoderOnlyOutput]
+BeamSearchOutput = Union[BeamSearchEncoderDecoderOutput, BeamSearchDecoderOnlyOutput]
+BeamSampleOutput = Union[BeamSampleEncoderDecoderOutput, BeamSampleDecoderOnlyOutput]
+ContrastiveSearchOutput = Union[ContrastiveSearchEncoderDecoderOutput, ContrastiveSearchDecoderOnlyOutput]
# Typing shortcuts
GenerateNonBeamOutput = Union[GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput]
@@ -355,17 +393,176 @@ class GenerationMixin:
To learn more about decoding strategies refer to the [text generation strategies guide](../generation_strategies).
"""
+ def load_custom_generate(
+ self,
+ pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
+ trust_remote_code: Optional[bool] = None,
+ **kwargs,
+ ) -> Callable:
+ """
+ Loads and returns a custom generate function, given a model repo.
+
+ Args:
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
+ Can be either:
+ - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
+ - A path to a *directory* containing model weights saved using
+ [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
+ trust_remote_code (`bool`, *optional*):
+ Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
+ should only be set to `True` for repositories you trust and in which you have read the code, as it will
+ execute code present on the Hub on your local machine.
+ **kwargs:
+ Additional keyword arguments for remote code loading.
+
+ Raises:
+ OSError: If `pretrained_model_name_or_path` does not contain a `custom_generate` subdirectory.
+
+ Returns:
+ A callable that can be used to generate text.
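+
+ Example (an illustrative sketch; `some-org/custom-decoding-demo` below is a hypothetical repository name):
+
+ ```python
+ >>> # `model` is any generation-capable model; the repo is assumed to expose `custom_generate/generate.py`
+ >>> custom_generate_fn = model.load_custom_generate("some-org/custom-decoding-demo", trust_remote_code=True)
+ >>> outputs = custom_generate_fn(model=model, input_ids=input_ids, max_new_tokens=20)
+ ```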
+ """
+ # Does `pretrained_model_name_or_path` have a `custom_generate` subdirectory? If not -> OSError
+ is_local_code = os.path.exists(pretrained_model_name_or_path)
+ has_custom_generate_folder = True
+ if is_local_code:
+ if not os.path.exists(os.path.join(pretrained_model_name_or_path, "custom_generate/generate.py")):
+ has_custom_generate_folder = False
+ else:
+ if not file_exists(pretrained_model_name_or_path, "custom_generate/generate.py"):
+ has_custom_generate_folder = False
+
+ if not has_custom_generate_folder:
+ raise OSError(
+ f"`{pretrained_model_name_or_path}` does not contain a `custom_generate` subdirectory with a "
+ "`generate.py` file, can't load the custom generate function."
+ )
+
+ # Handle opt-in `trust_remote_code` and related exceptions
+ error_message = (
+ f"The repository `{pretrained_model_name_or_path}` contains custom generation code that will override "
+ "the default `generate` method."
+ )
+ resolve_trust_remote_code(
+ trust_remote_code,
+ pretrained_model_name_or_path,
+ has_local_code=is_local_code,
+ has_remote_code=not is_local_code,
+ error_message=error_message,
+ )
+
+ # Load the custom generate function
+ check_python_requirements(
+ pretrained_model_name_or_path, requirements_file="custom_generate/requirements.txt", **kwargs
+ )
+ module = get_cached_module_file(
+ pretrained_model_name_or_path, module_file="custom_generate/generate.py", **kwargs
+ )
+ custom_generate_function = get_class_in_module("generate", module)
+ return custom_generate_function
+
+ def _cache_dependant_input_preparation(
+ self,
+ input_ids: ms.Tensor,
+ inputs_embeds: Optional[ms.Tensor],
+ cache_position: Optional[ms.Tensor],
+ ) -> tuple[ms.Tensor, ms.Tensor]:
+ """
+ Generic cache-dependent input preparation.
+ The code is put in a separate function to allow granular unit testing,
+ as it needs a different implementation to be exportable.
+
+ If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
+ - Exception 1: when passing input_embeds, input_ids may be missing entries
+ - Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
+ - Exception 3: with synced GPUs cache_position may go out of bounds, but we only want a dummy token in that case.
+ - Exception 4: if `inputs_embeds` is passed, slice it through `cache_position` to keep only the unprocessed tokens and
+ generate the first token for each sequence; the generated input ids are then used for continuation.
+
+ The current implementation does not rely on ``self`` and could be
+ a class method. It is left as a standard method to be easily rewritten.
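+
+ Illustrative example (values assumed): with a cache that already holds 8 processed tokens and an
+ `input_ids` of length 10, `cache_position` is `[8, 9]`, so only `input_ids[:, -2:]` (the two
+ unprocessed tokens) is kept for the next forward pass.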
+ """
+ # fixme there is no implementation for torch dynamo exporting
+ if inputs_embeds is not None and input_ids.shape[1] == 0: # Exception 4
+ inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
+ elif inputs_embeds is not None or (cache_position[-1] >= input_ids.shape[1]): # Exception 1 # Exception 3
+ input_ids = input_ids[:, -cache_position.shape[0] :]
+ elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
+ input_ids = input_ids[:, cache_position]
+ return inputs_embeds, input_ids
+
+ def _cache_dependant_input_preparation_exporting(
+ self,
+ input_ids: ms.Tensor,
+ inputs_embeds: Optional[ms.Tensor],
+ cache_position: Optional[ms.Tensor],
+ ) -> tuple[ms.Tensor, ms.Tensor]:
+ """
+ This method implements ``_cache_dependant_input_preparation`` using ``mint.cond``
+ (the ``torch.cond`` counterpart) so that the control flow stays exportable, mirroring the original
+ :func:`torch.export.export`-compatible implementation.
+ The code is put in a separate function to allow granular unit testing.
+ """
+ if inputs_embeds is None:
+ input_ids = input_ids[:, cache_position]
+ else:
+ # This is the code we need to implement with `mint.cond`:
+ # if input_ids.shape[1] == 0:
+ # inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
+ # else:
+ # if cache_position[-1] >= input_ids.shape[1]:
+ # input_ids = input_ids[:, -cache_position.shape[0] :]
+ # else:
+ # if input_ids.shape[1] != cache_position.shape[0]:
+ # input_ids = input_ids[:, cache_position]
+ def branch_1(inputs_embeds, cache_position):
+ return inputs_embeds[:, -cache_position.shape[0] :]
+
+ def branch_2(input_ids, cache_position):
+ return input_ids[:, -cache_position.shape[0] :]
+
+ def branch_3(input_ids, cache_position):
+ return input_ids[:, cache_position]
+
+ inputs_embeds, input_ids = mint.cond(
+ input_ids.shape[1] == 0,
+ (
+ lambda input_ids, inputs_embeds, cache_position: (
+ branch_1(inputs_embeds, cache_position),
+ input_ids,
+ )
+ ),
+ (
+ lambda input_ids, inputs_embeds, cache_position: (
+ inputs_embeds,
+ mint.cond(
+ cache_position[-1] >= input_ids.shape[1],
+ branch_2,
+ lambda input_ids, cache_position: (
+ mint.cond(
+ input_ids.shape[1] != cache_position.shape[0],
+ branch_3,
+ (lambda input_ids, cache_position: input_ids),
+ [input_ids, cache_position],
+ )
+ ),
+ [input_ids, cache_position],
+ ),
+ )
+ ),
+ [input_ids, inputs_embeds, cache_position],
+ )
+ return inputs_embeds, input_ids
+
def prepare_inputs_for_generation(
self,
input_ids,
- past_key_values: Union[Cache, Tuple] = None,
+ past_key_values: Union[Cache, tuple] = None,
attention_mask: Optional[ms.Tensor] = None,
inputs_embeds: Optional[ms.Tensor] = None,
cache_position: Optional[ms.Tensor] = None,
**kwargs,
):
"""
- Prepare the model inputs for generation. In includes operations like computing the 4D attention mask or
+ Prepare the model inputs for generation. It includes operations like computing the 4D attention mask or
slicing inputs given the existing cache.
See the forward pass in the model documentation for expected arguments (different models might have different
@@ -382,11 +579,11 @@ def prepare_inputs_for_generation(
# (this alternative is not as robust as calling `generate` and letting it create `cache_position`)
elif cache_position is None:
past_length = (
- get_seq_length(past_key_values, dynamic=self._supports_default_dynamic_input())
+ get_seq_length(past_key_values, dynamic=not self._supports_default_jit())
if past_key_values is not None
else 0
)
- cache_position = ops.arange(past_length, input_ids.shape[1], dtype=ms.int32)
+ cache_position = mint.arange(past_length, input_ids.shape[1], dtype=ms.int32)
if kwargs["use_cache"]:
model_inputs["cache_position"] = cache_position
@@ -419,7 +616,7 @@ def prepare_inputs_for_generation(
model_inputs["inputs_embeds"] = inputs_embeds
else:
# For static shape, padding input_id to max_len when no cache, in prefill-stage
- if not self._supports_default_dynamic_input() and past_key_values is None:
+ if self._supports_default_jit() and past_key_values is None:
pad_len = max(0, attention_mask.shape[1] - input_ids.shape[1])
input_ids = mint.nn.functional.pad(input_ids, (0, pad_len), value=0)
model_inputs[input_ids_key] = input_ids
@@ -450,7 +647,7 @@ def prepare_inputs_for_generation(
_past_key_values = past_key_values
if (
isinstance(past_key_values, (tuple, list))
- and get_seq_length(past_key_values, dynamic=self._supports_default_dynamic_input()) == 0
+ and get_seq_length(past_key_values, dynamic=not self._supports_default_jit()) == 0
):
_past_key_values = None
@@ -460,7 +657,7 @@ def prepare_inputs_for_generation(
if model_inputs.get("inputs_embeds") is not None
else model_inputs[input_ids_key].shape[1]
)
- if self._supports_default_dynamic_input() or attention_mask is None:
+ if not self._supports_default_jit() or attention_mask is None:
model_input = model_input[:, -current_input_length:]
else: # static shape input
cur_len = attention_mask.sum(-1).max()
@@ -487,11 +684,19 @@ def prepare_inputs_for_generation(
base_model, "_prepare_4d_causal_attention_mask_with_cache_position", None
)
if causal_mask_creation_function is None:
- logger.warning_once(
- f"{self.__class__.__name__} has no `_prepare_4d_causal_attention_mask_with_cache_position` method "
- "defined in its base modeling class. Compiled forward passes will be sub-optimal. If you're "
- "writing code, see Llama for an example implementation. If you're a user, please report this "
- "issue on GitHub."
+ token_type_ids = model_inputs.get("token_type_ids", None)
+ position_ids = model_inputs.get(position_ids_key, None)
+ # Some models may overwrite the general one
+ causal_mask_creation_function = getattr(self, "create_masks_for_generate", create_masks_for_generate)
+ attention_mask = causal_mask_creation_function(
+ config=self.config,
+ # we only need batch size, seq_length and dtype here - we don't care about the values of the embeddings
+ input_embeds=mint.empty((batch_size, sequence_length), dtype=self.dtype),
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ past_key_values=past_key_values,
+ position_ids=position_ids,
+ token_type_ids=token_type_ids,
)
else:
attention_mask = causal_mask_creation_function(
@@ -602,8 +807,8 @@ def _prepare_model_inputs(
self,
inputs: Optional[ms.Tensor] = None,
bos_token_id: Optional[ms.Tensor] = None,
- model_kwargs: Optional[Dict[str, ms.Tensor]] = None,
- ) -> Tuple[ms.Tensor, Optional[str], Dict[str, ms.Tensor]]:
+ model_kwargs: Optional[dict[str, ms.Tensor]] = None,
+ ) -> tuple[ms.Tensor, Optional[str], dict[str, ms.Tensor]]:
"""
This function extracts the model-specific `inputs` for generation.
"""
@@ -638,7 +843,9 @@ def _prepare_model_inputs(
# - encoder-decoder models should complain if the user attempts to pass `inputs_embeds` and `input_ids`, and
# pull the former to inputs. It will be used in place of `input_ids` to get the encoder hidden states.
if input_name == "input_ids" and "inputs_embeds" in model_kwargs:
- if not self.config.is_encoder_decoder:
+ if model_kwargs["inputs_embeds"] is None:
+ model_kwargs.pop("inputs_embeds")
+ elif not self.config.is_encoder_decoder:
has_inputs_embeds_forwarding = "inputs_embeds" in set(
inspect.signature(self.prepare_inputs_for_generation).parameters.keys()
)
@@ -653,10 +860,11 @@ def _prepare_model_inputs(
model_kwargs["input_ids"] = self._maybe_initialize_input_ids_for_generation(
inputs, bos_token_id, model_kwargs=model_kwargs
)
+ inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"
else:
if inputs is not None:
raise ValueError("You passed `inputs_embeds` and `input_ids` to `.generate()`. Please pick one.")
- inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"
+ inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"
# 4. if `inputs` is still None, try to create `input_ids` from BOS token
inputs = self._maybe_initialize_input_ids_for_generation(inputs, bos_token_id, model_kwargs)
@@ -666,7 +874,7 @@ def _maybe_initialize_input_ids_for_generation(
self,
inputs: Optional[ms.Tensor] = None,
bos_token_id: Optional[ms.Tensor] = None,
- model_kwargs: Optional[Dict[str, ms.Tensor]] = None,
+ model_kwargs: Optional[dict[str, ms.Tensor]] = None,
) -> ms.Tensor:
"""Initializes input ids for generation, if necessary."""
if inputs is not None:
@@ -676,7 +884,7 @@ def _maybe_initialize_input_ids_for_generation(
if self.config.is_encoder_decoder and encoder_outputs is not None:
# make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding
shape = encoder_outputs.last_hidden_state.shape[:-1]
- return ops.ones(shape, dtype=ms.int32) * -100
+ return mint.ones(shape, dtype=ms.int32) * -100
# If there is some tensor in `model_kwargs`, we can infer the batch size from it. This is helpful with
# soft-prompting or in multimodal implementations built on top of decoder-only language models.
@@ -687,18 +895,18 @@ def _maybe_initialize_input_ids_for_generation(
break
if "inputs_embeds" in model_kwargs:
- return ops.ones((batch_size, 0), dtype=ms.int32)
+ return mint.ones((batch_size, 0), dtype=ms.int32)
if bos_token_id is None:
raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.")
- return ops.ones((batch_size, 1), dtype=ms.int32) * bos_token_id
+ return mint.ones((batch_size, 1), dtype=ms.int32) * bos_token_id
def _prepare_attention_mask_for_generation(
self,
inputs_tensor: ms.Tensor,
generation_config: GenerationConfig,
- model_kwargs: Dict[str, Any],
+ model_kwargs: dict[str, Any],
) -> ms.Tensor:
pad_token_id = generation_config._pad_token_tensor
eos_token_id = generation_config._eos_token_tensor
@@ -708,7 +916,7 @@ def _prepare_attention_mask_for_generation(
inputs_tensor = model_kwargs["input_ids"]
# No information for attention mask inference -> return default attention mask
- default_attention_mask = ops.ones(inputs_tensor.shape[:2], dtype=ms.int32)
+ default_attention_mask = mint.ones(inputs_tensor.shape[:2], dtype=ms.int32)
if pad_token_id is None:
return default_attention_mask
@@ -736,7 +944,7 @@ def _prepare_encoder_decoder_kwargs_for_generation(
model_kwargs,
model_input_name: Optional[str],
generation_config: GenerationConfig,
- ) -> Dict[str, Any]:
+ ) -> dict[str, Any]:
# 1. get encoder
encoder = self.get_encoder()
@@ -768,10 +976,10 @@ def _prepare_decoder_input_ids_for_generation(
self,
batch_size: int,
model_input_name: str,
- model_kwargs: Dict[str, ms.Tensor],
+ model_kwargs: dict[str, ms.Tensor],
decoder_start_token_id: ms.Tensor,
**ignore_kwargs,
- ) -> Tuple[ms.Tensor, Dict[str, ms.Tensor]]:
+ ) -> tuple[ms.Tensor, dict[str, ms.Tensor]]:
"""Prepares `decoder_input_ids` for generation with encoder-decoder models"""
# 1. Check whether the user has defined `decoder_input_ids` manually. To facilitate in terms of input naming,
# we also allow the user to pass it under `input_ids`, if the encoder does not use it as the main input.
@@ -790,7 +998,7 @@ def _prepare_decoder_input_ids_for_generation(
)
decoder_start_token_id = decoder_start_token_id.view(-1, 1)
else:
- decoder_start_token_id = mint.ones((batch_size, 1), dtype=ms.int64) * decoder_start_token_id
+ decoder_start_token_id = mint.ones((batch_size, 1), dtype=ms.int32) * decoder_start_token_id
# 3. Encoder-decoder models expect the `decoder_input_ids` to start with a special token. Let's ensure that.
# no user input -> use decoder_start_token_id as decoder_input_ids
@@ -808,12 +1016,12 @@ def _prepare_decoder_input_ids_for_generation(
# user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust
# decoder_attention_mask if provided)
elif (decoder_input_ids[:, 0] != decoder_start_token_id[:, 0]).all().item():
- decoder_input_ids = ops.cat([decoder_start_token_id, decoder_input_ids], axis=-1)
+ decoder_input_ids = mint.cat([decoder_start_token_id, decoder_input_ids], dim=-1)
if "decoder_attention_mask" in model_kwargs:
decoder_attention_mask = model_kwargs["decoder_attention_mask"]
- decoder_attention_mask = ops.cat(
- (ops.ones_like(decoder_attention_mask)[:, :1], decoder_attention_mask),
- axis=-1,
+ decoder_attention_mask = mint.cat(
+ (mint.ones_like(decoder_attention_mask)[:, :1], decoder_attention_mask),
+ dim=-1,
)
model_kwargs["decoder_attention_mask"] = decoder_attention_mask
@@ -825,7 +1033,7 @@ def _expand_inputs_for_generation(
is_encoder_decoder: bool = False,
input_ids: Optional[ms.Tensor] = None,
**model_kwargs,
- ) -> Tuple[ms.Tensor, Dict[str, Any]]:
+ ) -> tuple[ms.Tensor, dict[str, Any]]:
"""Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...]"""
if expand_size == 1:
return input_ids, model_kwargs
@@ -860,10 +1068,10 @@ def _expand_dict_for_generation(dict_to_expand):
def _update_model_kwargs_for_generation(
self,
outputs: ModelOutput,
- model_kwargs: Dict[str, Any],
+ model_kwargs: dict[str, Any],
is_encoder_decoder: bool = False,
num_new_tokens: int = 1,
- ) -> Dict[str, Any]:
+ ) -> dict[str, Any]:
# update past_key_values keeping its naming used in model code
for possible_cache_name in ALL_CACHE_NAMES:
if possible_cache_name in outputs:
@@ -878,16 +1086,16 @@ def _update_model_kwargs_for_generation(
# update token_type_ids with last value
if "token_type_ids" in model_kwargs:
token_type_ids = model_kwargs["token_type_ids"]
- model_kwargs["token_type_ids"] = ops.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], axis=-1)
+ model_kwargs["token_type_ids"] = mint.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)
if not is_encoder_decoder:
# update attention mask
if "attention_mask" in model_kwargs:
attention_mask = model_kwargs["attention_mask"]
- if self._supports_default_dynamic_input():
- model_kwargs["attention_mask"] = ops.cat(
- [attention_mask, ops.ones((attention_mask.shape[0], 1), dtype=attention_mask.dtype)], axis=-1
+ if not self._supports_default_jit():
+ model_kwargs["attention_mask"] = mint.cat(
+ [attention_mask, mint.ones((attention_mask.shape[0], 1), dtype=attention_mask.dtype)], dim=-1
)
else: # update static attention mask
cur_lens = attention_mask.sum(-1)
@@ -903,18 +1111,18 @@ def _update_model_kwargs_for_generation(
# update decoder attention mask
if "decoder_attention_mask" in model_kwargs:
decoder_attention_mask = model_kwargs["decoder_attention_mask"]
- model_kwargs["decoder_attention_mask"] = ops.cat(
+ model_kwargs["decoder_attention_mask"] = mint.cat(
[
decoder_attention_mask,
- ops.ones((decoder_attention_mask.shape[0], 1), dtype=decoder_attention_mask.dtype),
+ mint.ones((decoder_attention_mask.shape[0], 1), dtype=decoder_attention_mask.dtype),
],
- axis=-1,
+ dim=-1,
)
if model_kwargs.get("use_cache", True):
# the first step for static shape
if (
- not self._supports_default_dynamic_input()
+ self._supports_default_jit()
and model_kwargs.get("attention_mask", None) is not None
and model_kwargs["attention_mask"].shape[-1] == model_kwargs["cache_position"].shape[0]
):
@@ -927,7 +1135,7 @@ def _update_model_kwargs_for_generation(
model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens
else:
past_positions = model_kwargs.pop("cache_position")
- if self._supports_default_dynamic_input() or model_kwargs.get("attention_mask", None) is None:
+ if not self._supports_default_jit() or model_kwargs.get("attention_mask", None) is None:
cur_idx = int(past_positions[-1]) + 1
new_positions = mint.arange(cur_idx, cur_idx + num_new_tokens, dtype=past_positions.dtype)
model_kwargs["cache_position"] = mint.cat((past_positions, new_positions))
@@ -938,7 +1146,7 @@ def _update_model_kwargs_for_generation(
cache_position = mint.cat((past_positions[:cur_idx], new_positions))
if cache_position.shape[-1] < max_len: # pad to max_len
cache_position = mint.cat(
- (cache_position, ops.zeros((max_len - cache_position.shape[-1]), dtype=cache_position.dtype))
+ (cache_position, mint.zeros((max_len - cache_position.shape[-1]), dtype=cache_position.dtype))
)
model_kwargs["cache_position"] = cache_position
@@ -959,7 +1167,7 @@ def _get_candidate_generator(
logits_processor: LogitsProcessorList,
target_tokenizer: "PreTrainedTokenizerBase",
assistant_tokenizer: "PreTrainedTokenizerBase",
- model_kwargs: Dict,
+ model_kwargs: dict,
) -> CandidateGenerator:
"""
Returns the candidate generator to be used in `assisted_generation`
@@ -985,7 +1193,11 @@ def _get_candidate_generator(
elif different_tokenizers:
if generation_config.do_sample is True:
atm_translator = AssistantVocabTranslatorCache.get_translator(
- target_tokenizer, assistant_tokenizer, self.config.vocab_size
+ target_tokenizer,
+ assistant_tokenizer,
+ self.config.get_text_config().vocab_size,
+ assistant_model=assistant_model,
+ assistant_prune_lm_head=True, # prune LM head of assistant model
)
candidate_generator = UniversalSpeculativeDecodingGenerator(
input_ids=input_ids,
@@ -1027,11 +1239,11 @@ def _get_candidate_generator(
def _get_logits_processor(
self,
generation_config: GenerationConfig,
- input_ids_seq_length: int,
- encoder_input_ids: ms.Tensor,
- prefix_allowed_tokens_fn: Callable[[int, ms.Tensor], List[int]],
- logits_processor: Optional[LogitsProcessorList],
- model_kwargs: Optional[Dict[str, Any]] = None,
+ input_ids_seq_length: Optional[int] = None,
+ encoder_input_ids: ms.Tensor = None,
+ prefix_allowed_tokens_fn: Optional[Callable[[int, ms.Tensor], list[int]]] = None,
+ logits_processor: Optional[LogitsProcessorList] = None,
+ model_kwargs: Optional[dict[str, Any]] = None,
negative_prompt_ids: Optional[ms.Tensor] = None,
negative_prompt_attention_mask: Optional[ms.Tensor] = None,
) -> LogitsProcessorList:
@@ -1041,6 +1253,8 @@ def _get_logits_processor(
"""
# instantiate processors list
processors = LogitsProcessorList()
+ if logits_processor is None:
+ logits_processor = []
if generation_config.guidance_scale is not None and generation_config.guidance_scale != 1:
processors.append(
@@ -1110,7 +1324,7 @@ def _get_logits_processor(
)
if (
generation_config.min_length is not None
- and generation_config._eos_token_tensor is not None
+ and getattr(generation_config, "_eos_token_tensor", None) is not None
and generation_config.min_length > 0
):
processors.append(
@@ -1121,7 +1335,7 @@ def _get_logits_processor(
)
if (
generation_config.min_new_tokens is not None
- and generation_config._eos_token_tensor is not None
+ and getattr(generation_config, "_eos_token_tensor", None) is not None
and generation_config.min_new_tokens > 0
):
processors.append(
@@ -1181,13 +1395,6 @@ def _get_logits_processor(
)
)
- # Fixme
- # if generation_config.forced_decoder_ids is not None:
- # raise ValueError(
- # "You have explicitly specified `forced_decoder_ids`. Please remove the `forced_decoder_ids` argument "
- # "in favour of `input_ids` or `decoder_input_ids` respectively.",
- # )
-
processors = self._merge_criteria_processor_list(processors, logits_processor)
# Processors previously known as `LogitsWarpers`, only applied with sampling strategies
@@ -1236,7 +1443,9 @@ def _get_logits_processor(
# Watermarking should be after all logits processing is finished (see #34630)
if generation_config.watermarking_config is not None:
- processors.append(generation_config.watermarking_config.construct_processor(self.config.vocab_size))
+ processors.append(
+ generation_config.watermarking_config.construct_processor(self.config.get_text_config().vocab_size)
+ )
# `LogitNormalization` should always be the last logit processor, when present
if generation_config.renormalize_logits is True:
@@ -1291,7 +1500,7 @@ def _merge_criteria_processor_list(
Merge user-defined processors/criteria with the ones instantiated inside `generate`. In case the same
processor/criteria is present on both lists, use the user-defined one.
- (Note: up to v4.49.0, this funtion threw an exception is the same logit processor was found twice.)
+ (Note: up to v4.49.0, this function threw an exception if the same logit processor was found twice.)
"""
if len(custom_list) == 0:
return default_list
@@ -1331,16 +1540,16 @@ def compute_transition_scores(
used). This is a convenient method to quickly obtain the scores of the selected tokens at generation time.
Parameters:
- sequences (`torch.LongTensor`):
+ sequences (`ms.Tensor`):
The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or
shorter if all batches finished early due to the `eos_token_id`.
- scores (`tuple(torch.FloatTensor)`):
+ scores (`tuple(ms.Tensor)`):
Transition scores for each vocabulary token at each generation step. Beam transition scores consisting
of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam.
- Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token),
+ tuple of `ms.Tensor` with up to `max_new_tokens` elements (one element for each generated token),
with each tensor of shape `(batch_size*num_beams, config.vocab_size)`.
- beam_indices (`torch.LongTensor`, *optional*):
- Beam indices of generated token id at each generation step. `torch.LongTensor` of shape
+ beam_indices (`ms.Tensor`, *optional*):
+ Beam indices of generated token id at each generation step. `ms.Tensor` of shape
`(batch_size*num_return_sequences, sequence_length)`. Only required if a `num_beams>1` at
generate-time.
normalize_logits (`bool`, *optional*, defaults to `False`):
@@ -1414,7 +1623,7 @@ def compute_transition_scores(
# 3. Optionally normalize the logits (across the vocab dimension)
if normalize_logits:
- scores = scores.reshape(-1, self.config.vocab_size, scores.shape[-1])
+ scores = scores.reshape(-1, self.config.get_text_config().vocab_size, scores.shape[-1])
scores = mint.nn.functional.log_softmax(scores, dim=1)
scores = scores.reshape(-1, scores.shape[-1])
@@ -1428,7 +1637,7 @@ def compute_transition_scores(
beam_indices[beam_indices_mask] = 0
# 6. multiply beam_indices with vocab size to gather correctly from scores
- beam_sequence_indices = beam_indices * self.config.vocab_size
+ beam_sequence_indices = beam_indices * self.config.get_text_config().vocab_size
# 7. Define which indices contributed to scores
cut_idx = sequences.shape[-1] - max_beam_length
@@ -1495,15 +1704,8 @@ def _validate_assistant(self, assistant_model, tokenizer, assistant_tokenizer):
f"to `generate()` {doc_reference}."
)
- def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
+ def _validate_model_kwargs(self, model_kwargs: dict[str, Any]):
"""Validates model kwargs for generation. Generate argument typos will also be caught here."""
- # If a `Cache` instance is passed, checks whether the model is compatible with it
- if isinstance(model_kwargs.get("past_key_values", None), Cache) and not self._supports_cache_class:
- raise ValueError(
- f"{self.__class__.__name__} does not support an instance of `Cache` as `past_key_values`. Please "
- "check the model documentation for supported cache formats."
- )
-
# Excludes arguments that are handled before calling any model function
if self.config.is_encoder_decoder:
for key in ["decoder_input_ids"]:
@@ -1652,8 +1854,8 @@ def _prepare_generated_length(
return generation_config
def _prepare_generation_config(
- self, generation_config: Optional[GenerationConfig], use_model_defaults: Optional[bool] = None, **kwargs: Dict
- ) -> Tuple[GenerationConfig, Dict]:
+ self, generation_config: Optional[GenerationConfig], use_model_defaults: Optional[bool] = None, **kwargs: dict
+ ) -> tuple[GenerationConfig, dict]:
"""
Prepares the base generation config, then applies any generation configuration options from kwargs. This
function handles retrocompatibility with respect to configuration files.
@@ -1689,6 +1891,8 @@ def _prepare_generation_config(
generation_config = self.generation_config
using_model_generation_config = True
+ # `torch.export.export` usually raises an exception if it is called with ``strict=True``;
+ # the `deepcopy` below can only be processed if ``strict=False``.
generation_config = copy.deepcopy(generation_config)
if not using_model_generation_config:
@@ -1700,16 +1904,25 @@ def _prepare_generation_config(
use_model_defaults is None and model_base_version >= version.parse("4.50.0")
):
modified_values = {}
- default_generation_config = GenerationConfig()
- for key, default_value in default_generation_config.__dict__.items():
+ global_default_generation_config = GenerationConfig()
+ model_generation_config = self.generation_config
+ # we iterate over the model's generation config: it may hold custom keys, which we'll want to copy
+ for key, model_gen_config_value in model_generation_config.__dict__.items():
if key.startswith("_") or key == "transformers_version": # metadata
continue
- custom_gen_config_value = getattr(generation_config, key)
- model_gen_config_value = getattr(self.generation_config, key)
- if custom_gen_config_value == default_value and model_gen_config_value != default_value:
+ global_default_value = getattr(global_default_generation_config, key, None)
+ custom_gen_config_value = getattr(generation_config, key, None)
+ if (
+ custom_gen_config_value == global_default_value
+ and model_gen_config_value != global_default_value
+ ):
modified_values[key] = model_gen_config_value
setattr(generation_config, key, model_gen_config_value)
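+ # illustrative example (values assumed): if the model's `generation_config` ships `top_k=20` while
+ # the user-provided config still holds the library default (`top_k=50`), `top_k=20` is adopted
+ # here; a value the user set explicitly is never overridden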
- if len(modified_values) > 0:
+ # edge case: the user may set `temperature=0.0` (implying greedy decoding, i.e. `do_sample=False`),
+ # while the model defaults set `do_sample=True`; force `do_sample=False` in that case
+ if generation_config.temperature == 0.0:
+ generation_config.do_sample = False
+ if use_model_defaults is None and len(modified_values) > 0:
logger.warning_once(
f"`generation_config` default values have been modified to match model-specific defaults: "
f"{modified_values}. If this is not desired, please set these values explicitly."
@@ -1731,6 +1944,8 @@ def _prepare_generation_config(
def _get_initial_cache_position(self, input_ids, model_kwargs):
"""Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length"""
+ if "cache_position" in model_kwargs and model_kwargs["cache_position"]:
+ return model_kwargs
# the lines below are equivalent to `mint.arange` [0,1,2,3, .., input_shape-1]
if "inputs_embeds" in model_kwargs and not self.config.is_encoder_decoder:
cache_position = mint.ones_like(model_kwargs["inputs_embeds"][0, :, 0], dtype=ms.int32).cumsum(0) - 1
@@ -1745,7 +1960,7 @@ def _get_initial_cache_position(self, input_ids, model_kwargs):
cache = model_kwargs["past_key_values"]
past_length = 0
if not isinstance(cache, Cache):
- past_length = get_seq_length(cache, dynamic=self._supports_default_dynamic_input())
+ past_length = get_seq_length(cache, dynamic=not self._supports_default_jit())
elif hasattr(cache, "get_seq_length") and cache.get_seq_length() is not None:
past_length = cache.get_seq_length()
@@ -1753,14 +1968,13 @@ def _get_initial_cache_position(self, input_ids, model_kwargs):
cache_position = cache_position[past_length:]
# for static input, cache fallback to static shape
- if not self._supports_default_dynamic_input() and model_kwargs.get("attention_mask", None) is not None:
+ if self._supports_default_jit() and model_kwargs.get("attention_mask", None) is not None:
attention_mask = model_kwargs["attention_mask"]
cur_len = int(attention_mask.sum(-1).max())
valid_len = cur_len - past_length
max_len = cache_position.shape[0]
if valid_len < max_len:
cache_position = cache_position[:valid_len]
- # FIXME: padding with zeros might be problematic, in case cache_position[-1] denotes the valid length
cache_position = mint.cat([cache_position, mint.zeros(max_len - valid_len, dtype=ms.int32)])
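+ # illustrative example (values assumed): with max_len=8, past_length=0 and 5 valid tokens in the
+ # attention mask, cache_position becomes [0, 1, 2, 3, 4, 0, 0, 0] (zero-padded to the static length)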
model_kwargs["cache_position"] = cache_position
@@ -1775,6 +1989,9 @@ def _get_cache(
Returns the resulting cache object.
"""
+ if cache_implementation == "hybrid" and "llama4" in getattr(self.config, "model_type", ""):
+ cache_implementation = "hybrid_chunked"
+
cache_cls: Cache = NEED_SETUP_CACHE_CLASSES_MAPPING[cache_implementation]
requires_cross_attention_cache = (
self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None
@@ -1790,9 +2007,8 @@ def _get_cache(
not hasattr(self, "_cache")
or (not isinstance(cache_to_check, cache_cls))
or cache_to_check.max_batch_size != batch_size
+ or cache_to_check.max_cache_len < max_cache_len
)
- if cache_implementation != "mamba":
- need_new_cache = need_new_cache or cache_to_check.max_cache_len < max_cache_len
if requires_cross_attention_cache and hasattr(self, "_cache"):
need_new_cache = (
@@ -1812,6 +2028,9 @@ def _get_cache(
"max_cache_len": max_cache_len,
"dtype": cache_dtype,
}
+ if cache_implementation in ["static", "hybrid", "offloaded_static"]:
+ cache_kwargs.update({"tp_size": self.tp_size})
+
self._cache = cache_cls(**cache_kwargs)
if requires_cross_attention_cache:
encoder_kwargs = cache_kwargs.copy()
@@ -1821,44 +2040,47 @@ def _get_cache(
self._cache.reset()
return self._cache
- def _supports_default_dynamic_cache(self) -> bool:
+ @classmethod
+ def _supports_default_dynamic_cache(cls) -> bool:
"""
Return `True` if current model can use a `DynamicCache` instance when initializing the `past_key_values`.
- This is mostly the same as `_supports_cache_class` attribute, but add exception for `Jamba` model which
- uses its own `HybridMambaAttentionDynamicCache` and do not need to initialize the Cache in advance in
- order to save memory (because no back and forth `to_legacy_cache` and `from_legacy_cache` will be performed
- for `HybridMambaAttentionDynamicCache`).
+ This adds an exception for some models, such as Mamba-based models, which use their own caches
+ and do not need to initialize the Cache in advance in order to save memory (because no back-and-forth
+ `to_legacy_cache` and `from_legacy_cache` will be performed for mamba-based models).
"""
- return (
- self._supports_cache_class
- and "jamba" not in self.__class__.__name__.lower()
- and "zamba" not in self.__class__.__name__.lower()
- and "bamba" not in self.__class__.__name__.lower()
- and "minimax" not in self.__class__.__name__.lower()
+ # NOTE: remove xlnet/reformer when the models are deprecated, non-standard model architecture/cache name
+ return not cls._is_stateful and all(
+ special_model_name not in cls.__name__.lower()
+ for special_model_name in [
+ "reformer",
+ "minimax",
+ "xlnet",
+ "lfm2",
+ ]
)
- def _supports_default_dynamic_input(self) -> bool:
+ def _supports_default_jit(self) -> bool:
"""
- Return `True` if current model use dynamic cache or _supports_dynamic_input is `True` , but add exception for `paged_attention` which use dynamic input
- shape.
+ Return `True` if the current model sets `_supports_jit`, does not use a default dynamic cache,
+ and is not running with `paged_attention` (which uses dynamic input shapes).
"""
return (
- self._supports_dynamic_input
- or self._supports_default_dynamic_cache()
- or self.config._attn_implementation == "paged_attention"
+ self._supports_jit
+ and not self._supports_default_dynamic_cache()
+ and self.config._attn_implementation != "paged_attention"
)
def _prepare_legacy_cache(
self,
generation_config: GenerationConfig,
- model_kwargs: Dict,
+ model_kwargs: dict,
cache_name: str,
batch_size: int,
):
"""
Prepares a static legacy cache (tuple of tuples) for `generate`.
"""
- if self._supports_default_dynamic_input(): # cache will be default None and will be processed further in model
+ if not self._supports_default_jit(): # cache will be default None and will be processed further in model
return
past = model_kwargs.get(cache_name, None)
@@ -1890,7 +2112,7 @@ def _prepare_legacy_cache(
def _prepare_cache_for_generation(
self,
generation_config: GenerationConfig,
- model_kwargs: Dict,
+ model_kwargs: dict,
assistant_model: "PreTrainedModel",
batch_size: int,
max_cache_length: int,
@@ -1899,8 +2121,10 @@ def _prepare_cache_for_generation(
Prepares the cache for generation (if applicable), given `generate`'s parameterization. If a cache is
instantiated, writes it to `model_kwargs`, under the name expected by the model.
"""
-
+ # fixme is_hybrid_cache is never used
+ # is_hybrid_cache = any(class_name in self.__class__.__name__.lower() for class_name in ["mamba", "falconh1"])
cache_name = "past_key_values" if "mamba" not in self.__class__.__name__.lower() else "cache_params"
+
requires_cross_attention_cache = (
self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None
)
@@ -1958,6 +2182,9 @@ def _prepare_cache_for_generation(
)
generation_config.cache_implementation = None
+ generation_config.cache_implementation = generation_config.cache_implementation or getattr(
+ self.config.get_text_config(decoder=True), "cache_implementation", None
+ )
if generation_config.cache_implementation is not None:
if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
if generation_config.cache_implementation == "static" and not self._supports_static_cache:
@@ -2085,13 +2312,13 @@ def _padding_inputs(
emb_length = inputs_embeds.shape[-1] if inputs_embeds is not None else 0
ignore_label_index = 0
- padded_input_ids = ops.zeros((bs, max_length), ms.int32)
+ padded_input_ids = mint.zeros((bs, max_length), dtype=ms.int32)
padded_labels = ops.full((bs, max_length), ignore_label_index, dtype=ms.int32)
- padded_position_ids = ops.zeros((bs, max_length), ms.int32)
- padded_attention_mask = ops.zeros((bs, max_length), ms.bool_)
+ padded_position_ids = mint.zeros((bs, max_length), dtype=ms.int32)
+ padded_attention_mask = mint.zeros((bs, max_length), dtype=ms.bool_)
padded_inputs_embeds = (
- ops.zeros((bs, max_length, emb_length), inputs_embeds.dtype) if inputs_embeds is not None else None
+ mint.zeros((bs, max_length, emb_length), dtype=inputs_embeds.dtype) if inputs_embeds is not None else None
)
_labels = labels
@@ -2099,15 +2326,15 @@ def _padding_inputs(
if attention_mask is None:
if inputs_embeds is not None:
- attention_mask = ops.ones(inputs_embeds.shape[:2], dtype=ms.bool_)
+ attention_mask = mint.ones(inputs_embeds.shape[:2], dtype=ms.bool_)
else:
- attention_mask = ops.ones(input_ids.shape[:], dtype=ms.bool_)
+ attention_mask = mint.ones(input_ids.shape[:], dtype=ms.bool_)
else:
attention_mask = attention_mask.astype(ms.bool_)
cur_len = int(attention_mask.sum(-1).max())
if position_ids is None:
- position_ids = ops.arange(0, cur_len, dtype=ms.int32)
+ position_ids = mint.arange(0, cur_len, dtype=ms.int32)
if labels is None:
labels = ops.full(
(
@@ -2124,7 +2351,7 @@ def _padding_inputs(
padded_attention_mask[batch_idx, :cur_len] = attention_mask[batch_idx][:]
padded_input_ids[batch_idx, : min(cur_len, input_ids[batch_idx].shape[0])] = input_ids[batch_idx][:]
padded_labels[batch_idx, :cur_len] = labels[batch_idx][:]
- padded_position_ids[batch_idx, :cur_len] = ops.arange(0, cur_len, dtype=position_ids.dtype)
+ padded_position_ids[batch_idx, :cur_len] = mint.arange(0, cur_len.item(), dtype=position_ids.dtype)
if inputs_embeds is not None:
padded_inputs_embeds[batch_idx, :cur_len] = inputs_embeds[batch_idx][:]
@@ -2138,19 +2365,48 @@ def _padding_inputs(
return new_input_ids, new_inputs_embeds, new_labels, new_position_ids, new_attention_mask
+ def _valid_auto_compile_criteria(self, model_kwargs: dict, generation_config: GenerationConfig) -> bool:
+ """
+ Determines whether to trigger auto-compilation of the model's forward pass at generation time.
+ """
+ # Override: honor `disable_compile` flag
+ if generation_config.disable_compile:
+ return False
+
+ # Base logic
+ valid_hardware = ms.get_context("mode") == 0 or bool(
+ generation_config.compile_config is not None and generation_config.compile_config._compile_all_devices
+ )
+ using_compilable_cache = (
+ isinstance(model_kwargs.get("past_key_values"), Cache) and model_kwargs["past_key_values"].is_compileable
+ )
+ # TODO @raushan `self._can_compile_fullgraph` can be removed and inferred from model arch (e.g. MoE doesn't support compile)
+ can_compile = valid_hardware and using_compilable_cache and self._can_compile_fullgraph
+
+ # Finally: if the user has manually specified compilation options, but compilation is not possible, let's warn
+ # them
+ if generation_config.compile_config is not None and not can_compile:
+ logger.warning_once(
+ "You have set `compile_config`, but we are unable to meet the criteria for compilation. Compilation "
+ "will be skipped."
+ )
+
+ return can_compile
+
def generate(
self,
inputs: Optional[ms.Tensor] = None,
generation_config: Optional[GenerationConfig] = None,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
- prefix_allowed_tokens_fn: Optional[Callable[[int, ms.Tensor], List[int]]] = None,
+ prefix_allowed_tokens_fn: Optional[Callable[[int, ms.Tensor], list[int]]] = None,
synced_gpus: Optional[bool] = None,
assistant_model: Optional["PreTrainedModel"] = None,
streamer: Optional["BaseStreamer"] = None,
negative_prompt_ids: Optional[ms.Tensor] = None,
negative_prompt_attention_mask: Optional[ms.Tensor] = None,
use_model_defaults: Optional[bool] = None,
+ custom_generate: Optional[str] = None,
**kwargs,
) -> Union[tuple, ms.Tensor]:
r"""
@@ -2191,13 +2447,13 @@ def generate(
generation config an error is thrown. If your stopping criteria depends on the `scores` input, make
sure you pass `return_dict_in_generate=True, output_scores=True` to `generate`. This feature is
intended for advanced users.
- prefix_allowed_tokens_fn (`Callable[[int, ms.Tensor], List[int]]`, *optional*):
+ prefix_allowed_tokens_fn (`Callable[[int, ms.Tensor], list[int]]`, *optional*):
If provided, this function constrains the beam search to allowed tokens only at each step. If not
provided, no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and
`input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned
on the batch ID `batch_id` and the previously generated tokens `input_ids`. This argument is useful
for constrained generation conditioned on the prefix, as described in [Autoregressive Entity
- Retrieval](https://arxiv.org/abs/2010.00904).
+ Retrieval](https://huggingface.co/papers/2010.00904).
synced_gpus (`bool`, *optional*):
Whether to continue running the while loop until max_length. Unless overridden this flag will be set to
`True` under DeepSpeed ZeRO Stage 3 multiple GPUs environment to avoid hanging if one GPU finished
@@ -2220,7 +2476,12 @@ def generate(
generation configuration (`model.generation_config`), as opposed to the global defaults
(`GenerationConfig()`). If unset, models saved starting from `v4.50` will consider this flag to be
`True`.
- kwargs (`Dict[str, Any]`, *optional*):
+ custom_generate (`str`, *optional*):
+ A string containing the name of a huggingface.co repository. If provided, the custom `generate`
+ function defined in that repository's `custom_generate/generate.py` file will be executed instead of the
+ standard `generate` method. Note that the logic for generation is entirely defined in that
+ repository, and the return type may be different from the standard `generate` method.
+ kwargs (`dict[str, Any]`, *optional*):
Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be
forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.
@@ -2241,9 +2502,28 @@ def generate(
- [`~generation.GenerateEncoderDecoderOutput`],
- [`~generation.GenerateBeamEncoderDecoderOutput`]
"""
+ # 0. If requested, load an arbitrary generation recipe from the Hub and run it instead
+ trust_remote_code = kwargs.pop("trust_remote_code", None)
+ if custom_generate is not None:
+ # Get all `generate` arguments in a single variable. Custom functions are responsible for handling them:
+ # they receive the same inputs as `generate`, with `model` instead of `self` and excluding the arguments
+ # used to trigger the custom generation. They can access methods from `GenerationMixin` through `model`.
+ global_keys_to_exclude = {
+ "self",
+ "kwargs",
+ "global_keys_to_exclude",
+ "trust_remote_code",
+ "custom_generate",
+ }
+ generate_arguments = {key: value for key, value in locals().items() if key not in global_keys_to_exclude}
+ generate_arguments.update(kwargs)
+
+ custom_generate_function = self.load_custom_generate(
+ custom_generate, trust_remote_code=trust_remote_code, **kwargs
+ )
+ return custom_generate_function(model=self, **generate_arguments)
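+ # For illustration only (an assumed layout, not shipped with this repository), a minimal
+ # `custom_generate/generate.py` on the Hub could implement a greedy loop roughly as follows:
+ #
+ #     def generate(model, input_ids, max_new_tokens=20, **kwargs):
+ #         for _ in range(max_new_tokens):
+ #             logits = model(input_ids).logits
+ #             next_token = mint.argmax(logits[:, -1, :], dim=-1, keepdim=True)
+ #             input_ids = mint.cat([input_ids, next_token], dim=-1)
+ #         return input_ids
+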
# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
- self._validate_model_class()
tokenizer = kwargs.pop("tokenizer", None) # Pull this out first, we only use it for stopping criteria
assistant_tokenizer = kwargs.pop("assistant_tokenizer", None) # only used for assisted generation
@@ -2279,7 +2559,7 @@ def generate(
generation_config._pad_token_tensor is not None
and batch_size > 1
and len(inputs_tensor.shape) == 2
- and ops.sum(inputs_tensor[:, -1] == generation_config._pad_token_tensor) > 0
+ and mint.sum(inputs_tensor[:, -1] == generation_config._pad_token_tensor) > 0
):
logger.warning(
"A decoder-only architecture is being used, but right-padding was detected! For correct "
@@ -2325,7 +2605,7 @@ def generate(
streamer.put(input_ids.asnumpy())
# 6. Prepare `max_length` depending on other stopping criteria.
- input_ids_length = input_ids.shape[-1]
+ input_ids_length = input_ids.shape[1]
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
generation_config = self._prepare_generated_length(
@@ -2338,8 +2618,9 @@ def generate(
)
# These lines will always select the last logits, which is not the right way for static shape
- # if self._supports_logits_to_keep() and "logits_to_keep" not in model_kwargs:
- # model_kwargs["logits_to_keep"] = 1
+ if not self._supports_default_jit():
+ if self._supports_logits_to_keep() and "logits_to_keep" not in model_kwargs:
+ model_kwargs["logits_to_keep"] = 1
self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
@@ -2386,12 +2667,92 @@ def generate(
# 10. go into different generation modes
if generation_mode == GenerationMode.ASSISTED_GENERATION:
- raise NotImplementedError
+ if generation_config.num_return_sequences > 1:
+ raise ValueError(
+ "num_return_sequences has to be 1 when doing assisted generate, "
+ f"but is {generation_config.num_return_sequences}."
+ )
+ if batch_size > 1:
+ raise ValueError("assisted generate is only supported for batch_size = 1")
+ if not model_kwargs["use_cache"]:
+ raise ValueError("assisted generate requires `use_cache=True`")
+ if generation_config.cache_implementation in ["static", "hybrid", "sliding_window"]:
+ raise ValueError("assisted generate is not supported with Static cache classes`")
+ if self._is_stateful:
+ # In assisted generation we need the ability to confirm whether the model would pick certain tokens,
+ # which is not possible with stateful models (they can't reset to a previous subset of generated text)
+ raise ValueError(
+ f"assisted generation is not supported with stateful models, such as {self.__class__.__name__}"
+ )
+
+ # 11. Get the candidate generator, given the parameterization
+ candidate_generator = self._get_candidate_generator(
+ generation_config=generation_config,
+ input_ids=input_ids,
+ inputs_tensor=inputs_tensor,
+ assistant_model=assistant_model,
+ logits_processor=logits_processor,
+ target_tokenizer=tokenizer,
+ assistant_tokenizer=assistant_tokenizer,
+ model_kwargs=model_kwargs,
+ )
+
+ # 12. run assisted generate
+ result = self._assisted_decoding(
+ input_ids,
+ candidate_generator=candidate_generator,
+ logits_processor=prepared_logits_processor,
+ stopping_criteria=prepared_stopping_criteria,
+ generation_config=generation_config,
+ synced_gpus=synced_gpus,
+ streamer=streamer,
+ **model_kwargs,
+ )
elif generation_mode == GenerationMode.DOLA_GENERATION:
- raise NotImplementedError
+ if not trust_remote_code:
+ logger.warning_once(
+ "DoLa Decoding is scheduled to be moved to a `custom_generate` repository in v4.55.0. "
+ "To prevent loss of backward compatibility, add `trust_remote_code=True` to your `generate` call."
+ )
+ if self._is_stateful:
+ # DoLa decoding was not designed for stateful models, and would require some changes
+ raise ValueError(
+ f"dola decoding is not supported with stateful models, such as {self.__class__.__name__}"
+ )
+ result = self._dola_decoding(
+ input_ids,
+ dola_layers=generation_config.dola_layers,
+ logits_processor=prepared_logits_processor,
+ stopping_criteria=prepared_stopping_criteria,
+ generation_config=generation_config,
+ synced_gpus=synced_gpus,
+ streamer=streamer,
+ **model_kwargs,
+ )
elif generation_mode == GenerationMode.CONTRASTIVE_SEARCH:
- raise NotImplementedError
+ if not trust_remote_code:
+ logger.warning_once(
+ "Contrastive Search is scheduled to be moved to a `custom_generate` repository in v4.55.0. "
+ "To prevent loss of backward compatibility, add `trust_remote_code=True` to your `generate` call."
+ )
+ if not model_kwargs["use_cache"]:
+ raise ValueError("Contrastive search requires `use_cache=True`")
+ if self._is_stateful:
+ # Just like assisted generation, we need to be able to rollback to a previous state (see comment above)
+ raise ValueError(
+ f"contrastive search is not supported with stateful models, such as {self.__class__.__name__}"
+ )
+
+ result = self._contrastive_search(
+ input_ids,
+ logits_processor=prepared_logits_processor,
+ stopping_criteria=prepared_stopping_criteria,
+ generation_config=generation_config,
+ synced_gpus=synced_gpus,
+ streamer=streamer,
+ **model_kwargs,
+ )
elif generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
# 11. expand input_ids with `num_return_sequences` additional sequences per batch
@@ -2431,7 +2792,110 @@ def generate(
**model_kwargs,
)
elif generation_mode == GenerationMode.GROUP_BEAM_SEARCH:
- raise NotImplementedError
+ logger.warning_once(
+ "Group Beam Search is scheduled to be moved to a `custom_generate` repository in v4.55.0. "
+ "To prevent loss of backward compatibility, add `trust_remote_code=True` to your `generate` call."
+ )
+ # 11. prepare beam search scorer
+ beam_scorer = BeamSearchScorer(
+ batch_size=batch_size,
+ num_beams=generation_config.num_beams,
+ length_penalty=generation_config.length_penalty,
+ do_early_stopping=generation_config.early_stopping,
+ num_beam_hyps_to_keep=generation_config.num_return_sequences,
+ num_beam_groups=generation_config.num_beam_groups,
+ max_length=generation_config.max_length,
+ )
+ # 12. interleave input_ids with `num_beams` additional sequences per batch
+ input_ids, model_kwargs = self._expand_inputs_for_generation(
+ input_ids=input_ids,
+ expand_size=generation_config.num_beams,
+ is_encoder_decoder=self.config.is_encoder_decoder,
+ **model_kwargs,
+ )
+ # 13. run beam search
+ result = self._group_beam_search(
+ input_ids,
+ beam_scorer,
+ logits_processor=prepared_logits_processor,
+ stopping_criteria=prepared_stopping_criteria,
+ generation_config=generation_config,
+ synced_gpus=synced_gpus,
+ **model_kwargs,
+ )
+
+ elif generation_mode == GenerationMode.CONSTRAINED_BEAM_SEARCH:
+ logger.warning_once(
+ "Constrained Beam Search is scheduled to be moved to a `custom_generate` repository in v4.55.0. "
+ "To prevent loss of backward compatibility, add `trust_remote_code=True` to your `generate` call."
+ )
+ final_constraints = []
+ if generation_config.constraints is not None:
+ final_constraints = generation_config.constraints
+
+ if generation_config.force_words_ids is not None:
+
+ def typeerror():
+ raise ValueError(
+ "`force_words_ids` has to either be a `list[list[list[int]]]` or `list[list[int]]` "
+ f"of positive integers, but is {generation_config.force_words_ids}."
+ )
+
+ if (
+ not isinstance(generation_config.force_words_ids, list)
+ or len(generation_config.force_words_ids) == 0
+ ):
+ typeerror()
+
+ for word_ids in generation_config.force_words_ids:
+ if isinstance(word_ids[0], list):
+ if not isinstance(word_ids, list) or len(word_ids) == 0:
+ typeerror()
+ if any(not isinstance(token_ids, list) for token_ids in word_ids):
+ typeerror()
+ if any(
+ any((not isinstance(token_id, int) or token_id < 0) for token_id in token_ids)
+ for token_ids in word_ids
+ ):
+ typeerror()
+
+ constraint = DisjunctiveConstraint(word_ids)
+ else:
+ if not isinstance(word_ids, list) or len(word_ids) == 0:
+ typeerror()
+ if any((not isinstance(token_id, int) or token_id < 0) for token_id in word_ids):
+ typeerror()
+
+ constraint = PhrasalConstraint(word_ids)
+ final_constraints.append(constraint)
+
+ # 11. prepare beam search scorer
+ constrained_beam_scorer = ConstrainedBeamSearchScorer(
+ constraints=final_constraints,
+ batch_size=batch_size,
+ num_beams=generation_config.num_beams,
+ length_penalty=generation_config.length_penalty,
+ do_early_stopping=generation_config.early_stopping,
+ num_beam_hyps_to_keep=generation_config.num_return_sequences,
+ max_length=generation_config.max_length,
+ )
+ # 12. interleave input_ids with `num_beams` additional sequences per batch
+ input_ids, model_kwargs = self._expand_inputs_for_generation(
+ input_ids=input_ids,
+ expand_size=generation_config.num_beams,
+ is_encoder_decoder=self.config.is_encoder_decoder,
+ **model_kwargs,
+ )
+ # 13. run beam search
+ result = self._constrained_beam_search(
+ input_ids,
+ constrained_beam_scorer=constrained_beam_scorer,
+ logits_processor=prepared_logits_processor,
+ stopping_criteria=prepared_stopping_criteria,
+ generation_config=generation_config,
+ synced_gpus=synced_gpus,
+ **model_kwargs,
+ )
# Convert to legacy cache format if requested
if (
@@ -2482,7 +2946,7 @@ def heal_tokens(self, input_ids: ms.Tensor, tokenizer: Optional["PreTrainedToken
# assumption: leading/trailing whitespace is not meaningful, so the prompts are
# stripped before re-tokenizing to desensitize generation to whitespace artefacts
prompts = [p.strip() for p in tokenizer.batch_decode(input_ids, skip_special_tokens=True)]
- input_ids = ms.Tensor(
+ input_ids = ms.tensor(
tokenizer(
prompts,
return_tensors="np",
@@ -2491,7 +2955,14 @@ def heal_tokens(self, input_ids: ms.Tensor, tokenizer: Optional["PreTrainedToken
)
# replace bos with pad to not condition healing on it
- input_ids = ops.where(input_ids == bos_token_id, pad_token_id, input_ids)
+ input_ids = mint.where(input_ids == bos_token_id, pad_token_id, input_ids)
+
+ """
+ the latter code assumes the input_ids is not empty,
+ input_id has to be checked if contains elements
+ """
+ if input_ids.numel() == 0:
+ return input_ids
tail_ids = input_ids[:, -1].tolist()
@@ -2502,11 +2973,18 @@ def heal_tokens(self, input_ids: ms.Tensor, tokenizer: Optional["PreTrainedToken
for batch_idx, (tail_id, tail_tok) in enumerate(zip(tail_ids, tail_toks)):
batch_ids = input_ids[batch_idx]
- if ops.all(batch_ids == pad_token_id).item():
+ if mint.all(batch_ids == pad_token_id).item():
continue # skip empty sequences (all pad ids)
# apply bias for alternatives (extensions) to the tail token
- seq_bias = {(alt_tok,): 10.0 for alt_tok in vocab_trie.values(prefix=tail_tok)}
+ """
+ seq_bias key has to be tuple with int so have to use
+ tokenizer function to convert str to int
+ """
+ seq_bias = {
+ (tokenizer.convert_tokens_to_ids(alt_tok),): 10.0 for alt_tok in vocab_trie.extensions(prefix=tail_tok)
+ }
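+ # e.g. (illustrative) if the tail token is "pre", the trie may yield extensions such as "pre",
+ # "prefix" and "present"; each extension's token id receives a +10.0 bias so healing can re-pick
+ # a longer continuation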
+
if len(seq_bias) == 1:
continue # skip if there are no token alternatives to heal with
@@ -2531,49 +3009,60 @@ def heal_tokens(self, input_ids: ms.Tensor, tokenizer: Optional["PreTrainedToken
return input_ids
- def _sample(
+ def _dola_decoding(
self,
input_ids: ms.Tensor,
+ dola_layers: Union[str, list[int]],
logits_processor: LogitsProcessorList,
stopping_criteria: StoppingCriteriaList,
generation_config: GenerationConfig,
synced_gpus: bool,
- streamer: Optional["BaseStreamer"],
+ streamer: "BaseStreamer",
**model_kwargs,
) -> Union[GenerateNonBeamOutput, ms.Tensor]:
r"""
- Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and
- can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
+ Generates sequences of token ids for models with a language modeling head using **dola decoding** and can be
+ used for decoder-only text models.
+ The method is based on the paper "DoLa: Decoding by Contrasting Layers Improves Factuality in Large Language
+ Models" (https://huggingface.co/papers/2309.03883) in ICLR 2024.
Parameters:
input_ids (`ms.Tensor` of shape `(batch_size, sequence_length)`):
The sequence used as a prompt for the generation.
+ dola_layers (`Union[str, list[int]]`):
+ The candidate layers used in contrasting layers of DoLa. It can be either 1) 'low' or 'high', which
+ means the lower part or higher part of the model layers, respectively, or 2) a list of layer indices
+ to be used for candidate layers. The 0-th layer is the word embedding layer of the model.
logits_processor (`LogitsProcessorList`):
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
used to modify the prediction scores of the language modeling head applied at each generation step.
- stopping_criteria (`StoppingCriteriaList`):
+ stopping_criteria (`StoppingCriteriaList`, *optional*):
An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
used to tell if the generation loop should stop.
generation_config ([`~generation.GenerationConfig`]):
The generation configuration to be used as parametrization of the decoding method.
synced_gpus (`bool`):
- Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
+ Whether to continue running the while loop until max_length (needed to avoid deadlocking with
+ `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
streamer (`BaseStreamer`, *optional*):
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
model_kwargs:
- Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
- an encoder-decoder model the kwargs should include `encoder_outputs`.
+ Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
+ If model is an encoder-decoder model the kwargs should include `encoder_outputs`.
Return:
- [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or `ms.Tensor`:
- A `ms.Tensor` containing the generated tokens (default behaviour) or a
+ [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`]
+ or `ms.Tensor`: A `ms.Tensor` containing the generated tokens (default behaviour) or a
[`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
`return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if
`model.config.is_encoder_decoder=True`.
"""
+ if self.config.is_encoder_decoder:
+ raise ValueError("DoLa decoding is only available for decoder-only models.")
# init values
+
pad_token_id = generation_config._pad_token_tensor
output_attentions = generation_config.output_attentions
output_hidden_states = generation_config.output_hidden_states
@@ -2590,92 +3079,78 @@ def _sample(
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
- # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
- if return_dict_in_generate and self.config.is_encoder_decoder:
- encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
- encoder_hidden_states = (
- model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
- )
-
- # Padding inputs to avoid dynamic shape
- if not self._supports_default_dynamic_input():
- (
- padded_input_ids,
- padded_inputs_embeds,
- padded_labels,
- padded_position_ids,
- padded_attention_mask,
- ) = self._padding_inputs(
- generation_config,
- input_ids,
- model_kwargs.get("inputs_embeds", None),
- model_kwargs.get("labels", None),
- model_kwargs.get("position_ids", None),
- model_kwargs.get("attention_mask", None),
- )
- input_ids = padded_input_ids
- model_kwargs["attention_mask"] = padded_attention_mask
- if model_kwargs.get("inputs_embeds", None) is not None:
- model_kwargs["inputs_embeds"] = padded_inputs_embeds
- if model_kwargs.get("labels", None) is not None:
- model_kwargs["labels"] = padded_labels
- if model_kwargs.get("position_ids", None) is not None:
- model_kwargs["position_ids"] = padded_position_ids
-
# keep track of which sequences are already finished
- batch_size, cur_len = input_ids.shape
+ batch_size, cur_length = input_ids.shape[:2]
+ unfinished_sequences = mint.ones(batch_size, dtype=ms.int64)
+ model_kwargs = self._get_initial_cache_position(cur_length, model_kwargs)
+
this_peer_finished = False
- unfinished_sequences = ops.ones(batch_size, dtype=ms.int32)
- model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
- multinomial = get_multinomial_op()
- step = 0
- s_time = time.time()
- graph_compiled_time_buffer = []
+ # prepare layers for DoLa decoding
+ final_layer = self.config.get_text_config().num_hidden_layers
+ # if the model has tied word embeddings, we skip the word embeddings (0-th) layer and start from the 2nd layer,
+        # as the early exit from word embeddings will become an identity function
+ # if the model is really shallow (<=2 layers), we use the 1st layer if it's not the final layer and the 0-th
+ # layer otherwise. Notice that DoLa does not help shallow models much.
+ if not self.config.tie_word_embeddings:
+ start_layer = 0
+ elif final_layer > 2:
+ start_layer = 2
+ elif final_layer == 2:
+ start_layer = 1
+ else:
+ start_layer = 0
+
+ # For `N`-layer models with `N <= 40` layers, the layers of `range(0, N // 2, 2)` and `range(N // 2, N, 2)`
+ # are used for `'low'` and `'high'` layers, respectively.
+ # For models with `N > 40` layers, the layers of `range(0, 20, 2)` and `range(N - 20, N, 2)` are used for
+ # `'low'` and `'high'` layers, respectively.
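+        # For example, assuming a 32-layer model with tied word embeddings (so `start_layer == 2`):
+        #   dola_layers="low"  -> candidate layers [2, 4, ..., 14]
+        #   dola_layers="high" -> candidate layers [16, 18, ..., 30]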
+ if isinstance(dola_layers, str) and dola_layers == "low":
+ if start_layer == final_layer // 2:
+ candidate_premature_layers = [start_layer]
+ else:
+ candidate_premature_layers = (
+ list(range(start_layer, final_layer // 2, 2))
+ if final_layer <= 40
+ else list(range(start_layer, 20, 2))
+ )
+ elif isinstance(dola_layers, str) and dola_layers == "high":
+ candidate_premature_layers = (
+ list(range(final_layer // 2, final_layer, 2))
+ if final_layer <= 40
+ else list(range(final_layer - 20, final_layer, 2))
+ )
+ # Set the `dola_layers` to a list of integers for layer indices to contrast manually specified layers.
+ elif isinstance(dola_layers, list):
+ candidate_premature_layers = [i for i in dola_layers if i < final_layer]
+ else:
+ raise ValueError("dola_layers must be either 'low', 'high' or a list of integers.")
+
+ lm_head = self.get_output_embeddings()
+ if lm_head is None:
+ raise ValueError("DoLa is not supported for models that don't have output embeddings.")
while self._has_unfinished_sequences(this_peer_finished, synced_gpus):
# prepare model inputs
-
- if input_ids.dtype == ms.int64:
- input_ids = input_ids.to(ms.int32)
-
- if self.config._attn_implementation == "paged_attention":
- model_kwargs["step"] = step
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
- # prepare variable output controls (note: some models won't accept all output controls)
- # Note that this is slightly different from hf transformers, since page attention dose not accept
- # undeterminante None/bool inputs
- model_inputs.update({"output_attentions": output_attentions})
- model_inputs.update({"output_hidden_states": output_hidden_states})
-
# forward pass to get next token
outputs = self(
**model_inputs,
- return_dict=False if ms.get_context("mode") == ms.GRAPH_MODE else True,
+ return_dict=True,
+ output_attentions=output_attentions,
+ output_hidden_states=True,
)
- if not isinstance(outputs, ModelOutput):
- outputs = ModelOutput(
- loss=None,
- logits=outputs[0],
- past_key_values=outputs[1] if model_inputs.get("use_cache", False) else None,
+
+ # .float() is needed to retain precision for later logits manipulations
+ final_layer_next_token_logits = outputs.logits[:, -1, :].detach().to(dtype=ms.float32)
+ final_logits = outputs.logits[:, -1, :].float()
+ candidate_premature_logits = {}
+ for candidate_premature_layer in candidate_premature_layers:
+ candidate_premature_logits[candidate_premature_layer] = lm_head(
+ outputs.hidden_states[candidate_premature_layer][:, -1, :]
)
- if self._supports_default_dynamic_input() or model_kwargs.get("attention_mask", None) is None:
- next_token_logits = outputs.logits[:, -1, :]
- else: # Get the right logits from static input shape
- attention_mask = model_kwargs["attention_mask"]
- cur_idx = int(attention_mask.sum(-1).max()) - 1
-
- if outputs.logits.shape[1] == attention_mask.shape[-1]:
- next_token_logits = outputs.logits[:, cur_idx, :] # (bs, seq, dim)
- else:
- next_token_logits = outputs.logits[:, -1, :]
-
- # `input_ids` obtain effective length after 1st step
- if input_ids.shape[1] == attention_mask.shape[1]:
- input_ids = input_ids[:, : cur_idx + 1]
-
# synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
model_kwargs = self._update_model_kwargs_for_generation(
outputs,
@@ -2685,18 +3160,9 @@ def _sample(
if synced_gpus and this_peer_finished:
continue
- step_time = time.time() - s_time
- if step < 2:
- print(f"==> sampling, step: {step}, time cost: {step_time:.5f}s")
- else:
- graph_compiled_time_buffer.append(step_time)
- token_speed = len(graph_compiled_time_buffer) / sum(graph_compiled_time_buffer)
- print(
- f"==> sampling, step: {step}, time cost: {step_time:.5f}s, running avg speed: {token_speed:.5f}token/s"
- )
- s_time = time.time()
- step += 1
-
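+            # DoLa contrast: following the cited paper, pick the premature layer whose next-token distribution
+            # diverges most (Jensen-Shannon divergence) from the final layer, then contrast the two log-probabilities.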
+ next_token_logits = _dola_select_contrast(
+ candidate_premature_layers, candidate_premature_logits, final_logits
+ )
# pre-process distribution
next_token_scores = logits_processor(input_ids, next_token_logits)
@@ -2705,7 +3171,7 @@ def _sample(
if output_scores:
scores += (next_token_scores,)
if output_logits:
- raw_logits += (next_token_logits,)
+ raw_logits += (final_layer_next_token_logits,)
if output_attentions:
decoder_attentions += (
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
@@ -2718,35 +3184,424 @@ def _sample(
(outputs.decoder_hidden_states,) if self.config.is_encoder_decoder else (outputs.hidden_states,)
)
- # token selection
- if do_sample:
- probs = ops.softmax(next_token_scores, axis=-1, dtype=ms.float32).to(next_token_scores.dtype)
- next_tokens = multinomial(probs, num_samples=1).squeeze(1)
- else:
- next_tokens = ops.argmax(next_token_scores, dim=-1)
+ if do_sample: # sample
+ probs = mint.nn.functional.softmax(next_token_scores, dim=-1)
+                next_tokens = mint.multinomial(probs, num_samples=1).squeeze(1)
+ else: # argmax
+ next_tokens = mint.argmax(next_token_scores, dim=-1)
# finished sentences should have their next token be a padding token
if has_eos_stopping_criteria:
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
- next_tokens = next_tokens.to(ms.int32)
# update generated ids, model inputs, and length for next step
- input_ids = ops.cat([input_ids, next_tokens[:, None]], axis=-1)
+ input_ids = mint.cat([input_ids, next_tokens[:, None]], dim=-1)
if streamer is not None:
- streamer.put(next_tokens.asnumpy())
+                streamer.put(next_tokens.asnumpy())
- unfinished_sequences = unfinished_sequences & ~ms.Tensor(stopping_criteria(input_ids, scores), ms.bool_)
+ # stop when each sentence is finished
+ unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
this_peer_finished = unfinished_sequences.max() == 0
- cur_len += 1
- # This is needed to properly delete outputs.logits which may be very large for first iteration
- # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
+ if streamer is not None:
+ streamer.end()
+
+ if return_dict_in_generate:
+ return GenerateDecoderOnlyOutput(
+ sequences=input_ids,
+ scores=scores,
+ logits=raw_logits,
+ attentions=decoder_attentions,
+ hidden_states=decoder_hidden_states,
+ past_key_values=model_kwargs.get("past_key_values"),
+ )
+ else:
+ return input_ids
+
+ def _contrastive_search(
+ self,
+ input_ids: ms.Tensor,
+ logits_processor: LogitsProcessorList,
+ stopping_criteria: StoppingCriteriaList,
+ generation_config: GenerationConfig,
+ synced_gpus: bool,
+ streamer: Optional["BaseStreamer"],
+ **model_kwargs,
+ ) -> Union[GenerateNonBeamOutput, ms.Tensor]:
+ r"""
+ Generates sequences of token ids for models with a language modeling head using **contrastive search** and can
+ be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
+
+ Parameters:
+ input_ids (`ms.Tensor` of shape `(batch_size, sequence_length)`):
+ The sequence used as a prompt for the generation.
+ logits_processor (`LogitsProcessorList`):
+ An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
+ used to modify the prediction scores of the language modeling head applied at each generation step.
+ stopping_criteria (`StoppingCriteriaList`):
+ An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
+ used to tell if the generation loop should stop.
+ generation_config ([`~generation.GenerationConfig`]):
+ The generation configuration to be used as parametrization of the decoding method.
+ synced_gpus (`bool`):
+ Whether to continue running the while loop until max_length (needed to avoid deadlocking with
+ `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
+ streamer (`BaseStreamer`, *optional*):
+ Streamer object that will be used to stream the generated sequences. Generated tokens are passed
+ through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
+ model_kwargs:
+ Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
+ If model is an encoder-decoder model the kwargs should include `encoder_outputs`.
+
+ Return:
+ [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`]
+ or `ms.Tensor`: A `ms.Tensor` containing the generated tokens (default behaviour) or a
+ [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
+ `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if
+ `model.config.is_encoder_decoder=True`.
+ """
+ # init values
+ has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
+ top_k = generation_config.top_k
+ penalty_alpha = generation_config.penalty_alpha
+ pad_token_id = generation_config._pad_token_tensor
+ output_attentions = generation_config.output_attentions
+ output_hidden_states = generation_config.output_hidden_states
+ output_scores = generation_config.output_scores
+ output_logits = generation_config.output_logits
+ return_dict_in_generate = generation_config.return_dict_in_generate
+ sequential = generation_config.low_memory
+
+ # init attention / hidden states / scores tuples
+ raw_logits = () if (return_dict_in_generate and output_logits) else None
+ scores = () if (return_dict_in_generate and output_scores) else None
+ decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
+ cross_attentions = () if (return_dict_in_generate and output_attentions) else None
+ decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
+
+ # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
+ if return_dict_in_generate and self.config.is_encoder_decoder:
+ encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
+ encoder_hidden_states = (
+ model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
+ )
+
+ # keep track of which sequences are already finished
+ batch_size, cur_len = input_ids.shape[:2]
+ unfinished_sequences = mint.ones(batch_size, dtype=ms.int64)
+ model_kwargs = self._get_initial_cache_position(cur_len, model_kwargs)
+
+ # Create cosine_matrix_mask based on the attention_mask
+ cosine_matrix_mask = mint.ones_like(input_ids, dtype=ms.int64)
+ if self.config.is_encoder_decoder:
+ if "decoder_attention_mask" in model_kwargs and model_kwargs["decoder_attention_mask"] is not None:
+ cosine_matrix_mask = model_kwargs["decoder_attention_mask"]
+ else:
+ cosine_matrix_mask = model_kwargs["attention_mask"]
+ cosine_matrix_mask = cosine_matrix_mask.repeat_interleave(top_k, dim=0)
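+        # the mask is repeated `top_k` times so it lines up with the batched top-k candidate forward passes below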
+
+ this_peer_finished = False
+
+ while self._has_unfinished_sequences(this_peer_finished, synced_gpus):
+ # if the first step in the loop, encode all the prefix and obtain: (1) past_key_values;
+ # (2) last_hidden_states; (3) logit_for_next_step; (4) update model kwargs for the next step
+ if model_kwargs.get("past_key_values") is None or (
+ isinstance(model_kwargs["past_key_values"], (Cache, EncoderDecoderCache))
+ and model_kwargs["past_key_values"].get_seq_length() == 0
+ ):
+ # prepare inputs
+ model_kwargs["use_cache"] = True
+ model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+
+ # encode the given prefix and prepare model inputs; encoder-decoder model process the prefix and save
+ # the `encoder_outputs`
+ outputs = self(
+ **model_inputs, return_dict=True, output_hidden_states=True, output_attentions=output_attentions
+ )
+
+ # last decoder hidden states will be used to compute the degeneration penalty (cosine similarity with
+ # previous tokens)
+ if self.config.is_encoder_decoder:
+ last_hidden_states = outputs.decoder_hidden_states[-1]
+ else:
+ last_hidden_states = outputs.hidden_states[-1]
+
+ # next logit for contrastive search to select top-k candidate tokens
+ # Copy is needed to avoid keeping a hanging ref to outputs.logits which may be very large for this first iteration
+ # (the clone itself is always small)
+                # float32 is needed to retain precision for later logits manipulations
+ logit_for_next_step = outputs.logits[:, -1, :].to(copy=True, dtype=ms.float32)
+
+ model_kwargs = self._update_model_kwargs_for_generation(
+ outputs,
+ model_kwargs,
+ is_encoder_decoder=self.config.is_encoder_decoder,
+ )
+
+ if not sequential:
+ # Expands model inputs top_k times, for batched forward passes (akin to beam search).
+ # input_ids is required for expanding visual inputs in qwen2vl
+ _, model_kwargs = self._expand_inputs_for_generation(
+ input_ids=input_ids,
+ expand_size=top_k,
+ is_encoder_decoder=self.config.is_encoder_decoder,
+ **model_kwargs,
+ )
+
+ past_key_values = model_kwargs.get("past_key_values")
+ if past_key_values is None:
+ raise ValueError(
+ f"{self.__class__.__name__} does not support caching and therefore **can't** be used "
+ "for contrastive search."
+ )
+ elif (
+ not isinstance(past_key_values[0], (tuple, ms.Tensor))
+ or past_key_values[0][0].shape[0] != batch_size
+ ):
+ raise ValueError(
+ f"{self.__class__.__name__} does not have a standard cache format and therefore **can't** be "
+ "used for contrastive search without further modifications."
+ )
+
+ # contrastive_search main logic start:
+ # contrastive search decoding consists of two steps: (1) candidate tokens recall; (2) candidate re-rank by
+ # degeneration penalty
+ processed_logit_for_next_step = logits_processor(input_ids, logit_for_next_step)
+ next_probs = mint.nn.functional.softmax(processed_logit_for_next_step, dim=-1)
+
+ top_k_probs, top_k_ids = mint.topk(next_probs, dim=-1, k=top_k)
+
+ # Store scores, attentions and hidden_states when required
+ if return_dict_in_generate:
+ if output_logits:
+ raw_logits += (logit_for_next_step,)
+ if output_scores:
+ scores += (processed_logit_for_next_step,)
+ if output_attentions:
+ decoder_attentions += (
+ (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
+ )
+ if self.config.is_encoder_decoder:
+ cross_attentions += (outputs.cross_attentions,)
+
+ if output_hidden_states:
+ decoder_hidden_states += (
+ (outputs.decoder_hidden_states,) if self.config.is_encoder_decoder else (outputs.hidden_states,)
+ )
+
+ # This is needed to properly delete outputs.logits which may be very large for this first iteration
+ # Otherwise a reference to outputs.logits is kept all along until after the next call to self.forward()
del outputs
+ if not sequential:
+ # Replicates the new past_key_values to match the `top_k` candidates
+ past = model_kwargs["past_key_values"]
+ # If it is a static cache, modify it in-place layer after layer to save memory
+ if isinstance(past, DynamicCache) or (
+ isinstance(past, EncoderDecoderCache) and isinstance(past.self_attention_cache, DynamicCache)
+ ):
+ past.batch_repeat_interleave(top_k)
+ else:
+ new_key_values = []
+ for layer in past:
+ items = []
+ # item is either the key or the value matrix
+ for item in layer:
+ items.append(item.repeat_interleave(top_k, dim=0))
+ new_key_values.append(tuple(items))
+
+ past = tuple(new_key_values)
+
+ model_kwargs["past_key_values"] = past
+
+ if sequential:
+ all_outputs = []
+ for i in range(top_k):
+ # compute the candidate tokens by the language model and collect their hidden_states
+ next_model_inputs = self.prepare_inputs_for_generation(top_k_ids[:, i].view(-1, 1), **model_kwargs)
+
+ outputs = self(
+ **next_model_inputs,
+ return_dict=True,
+ output_hidden_states=True,
+ output_attentions=output_attentions,
+ )
+ if isinstance(outputs["past_key_values"], DynamicCache) or (
+ isinstance(outputs["past_key_values"], EncoderDecoderCache)
+ and isinstance(outputs["past_key_values"].self_attention_cache, DynamicCache)
+ ):
+ # Remove past K-V from output since we don't need to stack later
+ outputs["past_key_values"] = None
+ # Remove last token from past K-V since we don't want to append it at this point
+ model_kwargs["past_key_values"].crop(-1)
+
+ all_outputs.append(outputs)
+ outputs = stack_model_outputs(all_outputs, self.config.get_text_config())
+
+ else:
+ # compute the candidate tokens by the language model and collect their hidden_states
+ # assembles top_k_ids into batch of size k
+ next_model_inputs = self.prepare_inputs_for_generation(top_k_ids.view(-1, 1), **model_kwargs)
+
+ outputs = self(
+ **next_model_inputs,
+ return_dict=True,
+ output_hidden_states=True,
+ output_attentions=output_attentions,
+ )
+
+ # This is essential to avoid having a last reference to the big past K-V and double the necessary memory
+ # in the next loop
+ del next_model_inputs
+
+ # name is different for encoder-decoder and decoder-only models
+ if self.config.is_encoder_decoder:
+ next_hidden = outputs.decoder_hidden_states[-1]
+ full_hidden_states = outputs.decoder_hidden_states
+ else:
+ next_hidden = outputs.hidden_states[-1]
+ full_hidden_states = outputs.hidden_states
+
+ # .float() is needed to retain precision for later logits manipulations
+ logits = outputs.logits[:, -1, :].float()
+ context_hidden = last_hidden_states.repeat_interleave(top_k, dim=0)
+
+ # compute the degeneration penalty and re-rank the candidates based on the degeneration penalty and the
+ # model confidence. Keeping `selected_idx` on CPU enables multi-device contrastive search and doesn't
+ # introduce (noticeable) slowdowns on single-device runs.
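+            # Each candidate's contrastive score is (1 - penalty_alpha) * p(candidate) minus penalty_alpha times
+            # its maximum cosine similarity to the context hidden states; `_ranking_fast` picks the argmax per item.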
+ selected_idx = _ranking_fast(
+ context_hidden, next_hidden, top_k_probs, cosine_matrix_mask, penalty_alpha, top_k
+ )
+ cosine_matrix_mask = mint.cat(
+ [cosine_matrix_mask, cosine_matrix_mask.new_ones((cosine_matrix_mask.shape[0], 1))], dim=-1
+ )
+
+            # This will be used instead of the previous inefficient torch.stack(torch.split())
+ augmented_idx = ms.Tensor([x + i * top_k for i, x in enumerate(selected_idx)])
+
+ # prepare for the next step: (1) next token_id; (2) past_key_values; (3) last_hidden_states for computing
+ # the degeneration penalty; (4) logits for selecting next top-k candidates; (5) selected tokens scores
+ # (model confidence minus degeneration penalty); (6) decoder hidden_states
+ next_tokens = top_k_ids[range(len(top_k_ids)), selected_idx]
+ next_hidden = mint.stack(mint.split(next_hidden.squeeze(dim=1), top_k))
+ next_hidden = next_hidden[range(batch_size), selected_idx, :]
+ last_hidden_states = mint.cat([last_hidden_states, next_hidden.unsqueeze(1)], dim=1)
+
+ next_decoder_hidden_states = ()
+ for layer in full_hidden_states:
+ layer = mint.stack(mint.split(layer, top_k))[range(batch_size), selected_idx, :]
+ next_decoder_hidden_states += (layer,)
+
+ # generate past_key_values cache of only the selected token
+ if sequential:
+ next_model_input = self.prepare_inputs_for_generation(
+ top_k_ids[:, selected_idx].view(-1, 1), **model_kwargs
+ )
+
+ selected_outputs = self(
+ **next_model_input,
+ return_dict=True,
+ output_hidden_states=False,
+ output_attentions=False,
+ )
+ next_past_key_values = selected_outputs["past_key_values"]
+
+ else:
+ next_past_key_values = None
+ for possible_cache_name in ALL_CACHE_NAMES:
+ next_past_key_values = next_past_key_values or getattr(outputs, possible_cache_name, None)
+ # Do it in-place layer per layer to save memory
+ if isinstance(next_past_key_values, DynamicCache) or (
+ isinstance(next_past_key_values, EncoderDecoderCache)
+ and isinstance(next_past_key_values.self_attention_cache, DynamicCache)
+ ):
+ next_past_key_values.batch_select_indices(augmented_idx)
+ else:
+ new_key_values = []
+ for layer in next_past_key_values:
+ items = []
+ # item is either the key or the value matrix
+ for item in layer:
+ items.append(item[augmented_idx, ...])
+ new_key_values.append(tuple(items))
+
+ next_past_key_values = tuple(new_key_values)
+
+ logit_for_next_step = mint.stack(mint.split(logits, top_k))[range(batch_size), selected_idx, :]
+
+ # Rebuilds the relevant parts of the model output for the selected token, for use in the next iteration
+ if self.config.is_encoder_decoder:
+ next_step_cross_attentions = ()
+ next_step_decoder_attentions = ()
+ if output_attentions:
+ for layer in outputs.cross_attentions:
+ layer = mint.stack(mint.split(layer, top_k, dim=0))[range(batch_size), selected_idx, ...]
+ next_step_cross_attentions += (layer,)
+ for layer in outputs.decoder_attentions:
+ layer = mint.stack(mint.split(layer, top_k, dim=0))[range(batch_size), selected_idx, ...]
+ next_step_decoder_attentions += (layer,)
+ outputs = Seq2SeqLMOutput(
+ past_key_values=next_past_key_values,
+ decoder_hidden_states=next_decoder_hidden_states,
+ decoder_attentions=next_step_decoder_attentions or None,
+ cross_attentions=next_step_cross_attentions or None,
+ )
+ else:
+ next_step_attentions = ()
+ if output_attentions:
+ for layer in outputs.attentions:
+ layer = mint.stack(mint.split(layer, top_k, dim=0))[range(batch_size), selected_idx, ...]
+ next_step_attentions += (layer,)
+ outputs = CausalLMOutputWithPast(
+ past_key_values=next_past_key_values,
+ hidden_states=next_decoder_hidden_states,
+ attentions=next_step_attentions or None,
+ )
+ # contrastive_search main logic end
+
+ # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
+ model_kwargs = self._update_model_kwargs_for_generation(
+ outputs,
+ model_kwargs,
+ is_encoder_decoder=self.config.is_encoder_decoder,
+ )
+ if synced_gpus and this_peer_finished:
+ continue
+
+ # finished sentences should have their next token be a padding token
+ if has_eos_stopping_criteria:
+ next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
+
+ # update generated ids, model inputs, and length for next step
+ input_ids = mint.cat([input_ids, next_tokens[:, None]], dim=-1)
+ if streamer is not None:
+                streamer.put(next_tokens.asnumpy())
+
+ # stop when each sentence is finished
+ unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
+ this_peer_finished = unfinished_sequences.max() == 0
+
if streamer is not None:
streamer.end()
if return_dict_in_generate:
+ # Contrastive search works by forward looking at the next token, so we need to exclude it from
+ # `past_key_values` to be consistent with the other decoding methods
+ if model_kwargs.get("past_key_values") is not None:
+ if isinstance(model_kwargs["past_key_values"], DynamicCache) or (
+ isinstance(model_kwargs["past_key_values"], EncoderDecoderCache)
+ and isinstance(model_kwargs["past_key_values"].self_attention_cache, DynamicCache)
+ ):
+ model_kwargs["past_key_values"].crop(-1)
+ else:
+ past_key_values = []
+ for layer in model_kwargs["past_key_values"]:
+ layer_past_key_values = []
+ for item in layer:
+ layer_past_key_values.append(item[..., :-1, :])
+ past_key_values.append(tuple(layer_past_key_values))
+ model_kwargs["past_key_values"] = tuple(past_key_values)
+
if self.config.is_encoder_decoder:
return GenerateEncoderDecoderOutput(
sequences=input_ids,
@@ -2771,235 +3626,1176 @@ def _sample(
else:
return input_ids
- # Auxiliary functions for beam search
- def _temporary_reorder_cache(self, past_key_values, beam_idx):
- """
- Temporary function to handle the different types of cache reordering processes while we roll out `Cache`.
- TODO: standardize cache formats and make all models compatible with `Cache`. It would remove the need
- for this function, with `Cache.reorder_cache` being the sole remaining code path
- """
- model_class = self.__class__.__name__.lower()
- # Exception 1: code path for models using the legacy cache format
- if isinstance(past_key_values, (tuple, list)):
- past_key_values = self._reorder_cache(past_key_values, beam_idx)
- # Exception 2: models with different cache formats. These are limited to `DynamicCache` until their
- # cache format is standardized, to avoid adding complexity to the codebase.
- elif "gptbigcode" in model_class:
- if not isinstance(past_key_values, (DynamicCache, EncoderDecoderCache)):
- raise ValueError(
- f"Using an unsupported cache format with {model_class}. Currently, it only supports the "
- "legacy tuple format or `DynamicCache`"
- )
- past_key_values = self._reorder_cache(past_key_values, beam_idx)
- past_key_values = DynamicCache.from_legacy_cache(past_key_values)
- # Standard code path: use the `Cache.reorder_cache`
- else:
- past_key_values.reorder_cache(beam_idx)
- return past_key_values
-
- @staticmethod
- def _flatten_beam_dim(tensor: ms.Tensor) -> ms.Tensor:
- """[batch_size, num_beams, ...] -> [batch_size * num_beams, ...]"""
- shape = list(tensor.shape)
- return mint.reshape(tensor, [shape[0] * shape[1]] + shape[2:])
-
- @staticmethod
- def _unflatten_beam_dim(tensor: ms.Tensor, batch_size: int, num_beams: int) -> ms.Tensor:
- """[batch_size * num_beams, ...] -> [batch_size, num_beams, ...]"""
- shape = list(tensor.shape)
- return mint.reshape(tensor, [batch_size, num_beams] + shape[1:])
+ def _sample(
+ self,
+ input_ids: ms.Tensor,
+ logits_processor: LogitsProcessorList,
+ stopping_criteria: StoppingCriteriaList,
+ generation_config: GenerationConfig,
+ synced_gpus: bool,
+ streamer: Optional["BaseStreamer"],
+ **model_kwargs,
+ ) -> Union[GenerateNonBeamOutput, ms.Tensor]:
+ r"""
+ Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and
+ can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
- @staticmethod
- def _gather_beams(tensor: ms.Tensor, beam_indices: ms.Tensor) -> ms.Tensor:
- """
- Gathers the beam slices indexed by beam_indices into new beam array.
- Args:
- tensor (`ms.Tensor`): A tensor containing data to be gathered. The tensor is a 2D or a 3D tensor
- with the two first dimensions depicting the batch and the beam dimensions.
- beam_indices (`ms.Tensor` of shape `(batch_size, num_beams_to_select)`): The indices of the beams to
- select .
- Returns:
+ Parameters:
+ input_ids (`ms.Tensor` of shape `(batch_size, sequence_length)`):
+ The sequence used as a prompt for the generation.
+ logits_processor (`LogitsProcessorList`):
+ An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
+ used to modify the prediction scores of the language modeling head applied at each generation step.
+ stopping_criteria (`StoppingCriteriaList`):
+ An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
+ used to tell if the generation loop should stop.
+ generation_config ([`~generation.GenerationConfig`]):
+ The generation configuration to be used as parametrization of the decoding method.
+ synced_gpus (`bool`):
+                Whether to continue running the while loop until max_length (needed to avoid deadlocking with
+                `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
+ streamer (`BaseStreamer`, *optional*):
+ Streamer object that will be used to stream the generated sequences. Generated tokens are passed
+ through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
+ model_kwargs:
+ Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
+ an encoder-decoder model the kwargs should include `encoder_outputs`.
+
+ Return:
+ [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or `ms.Tensor`:
+ A `ms.Tensor` containing the generated tokens (default behaviour) or a
+ [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
+ `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if
+ `model.config.is_encoder_decoder=True`.
+ """
+
+ # init values
+ pad_token_id = generation_config._pad_token_tensor
+ output_attentions = generation_config.output_attentions
+ output_hidden_states = generation_config.output_hidden_states
+ output_scores = generation_config.output_scores
+ output_logits = generation_config.output_logits
+ return_dict_in_generate = generation_config.return_dict_in_generate
+ has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
+ do_sample = generation_config.do_sample
+
+ # init attention / hidden states / scores tuples
+ scores = () if (return_dict_in_generate and output_scores) else None
+ raw_logits = () if (return_dict_in_generate and output_logits) else None
+ decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
+ cross_attentions = () if (return_dict_in_generate and output_attentions) else None
+ decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
+
+ # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
+ if return_dict_in_generate and self.config.is_encoder_decoder:
+ encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
+ encoder_hidden_states = (
+ model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
+ )
+
+ # Padding inputs to avoid dynamic shape
+ if self._supports_default_jit():
+ (
+ padded_input_ids,
+ padded_inputs_embeds,
+ padded_labels,
+ padded_position_ids,
+ padded_attention_mask,
+ ) = self._padding_inputs(
+ generation_config,
+ input_ids,
+ model_kwargs.get("inputs_embeds", None),
+ model_kwargs.get("labels", None),
+ model_kwargs.get("position_ids", None),
+ model_kwargs.get("attention_mask", None),
+ )
+ input_ids = padded_input_ids
+ model_kwargs["attention_mask"] = padded_attention_mask
+ if model_kwargs.get("inputs_embeds", None) is not None:
+ model_kwargs["inputs_embeds"] = padded_inputs_embeds
+ if model_kwargs.get("labels", None) is not None:
+ model_kwargs["labels"] = padded_labels
+ if model_kwargs.get("position_ids", None) is not None:
+ model_kwargs["position_ids"] = padded_position_ids
+
+ # keep track of which sequences are already finished
+ batch_size, cur_len = input_ids.shape
+ this_peer_finished = False
+ unfinished_sequences = mint.ones(batch_size, dtype=ms.int32)
+ model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
+
+ model_forward = self.__call__
+ compile_forward = self._valid_auto_compile_criteria(model_kwargs, generation_config)
+ if compile_forward:
+ os.environ["TOKENIZERS_PARALLELISM"] = "0"
+ # If we use FA2 and a static cache, we cannot compile with fullgraph
+ if self.config._attn_implementation == "flash_attention_2" and getattr(
+ model_kwargs.get("past_key_values"), "is_compileable", False
+ ):
+ if generation_config.compile_config is None:
+ generation_config.compile_config = CompileConfig(fullgraph=False)
+ # only raise warning if the user passed an explicit compile-config (otherwise, simply change the default without confusing the user)
+ elif generation_config.compile_config.fullgraph:
+ logger.warning_once(
+ "When using Flash Attention 2 and a static cache, you cannot use the option `CompileConfig(fullgraph=True)` as "
+ "FA2 introduces graph breaks. We overrode the option with `fullgraph=False`."
+ )
+ generation_config.compile_config.fullgraph = False
+ model_forward = self.get_compiled_call(generation_config.compile_config)
+
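+        # Optional prefill chunking: run the prompt through the model in smaller chunks to bound peak memory
+        # during the first forward pass; decoding then continues token by token as usual.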
+ if generation_config.prefill_chunk_size is not None:
+ model_kwargs = self._prefill_chunking(input_ids, generation_config, **model_kwargs)
+ is_prefill = False
+ else:
+ is_prefill = True
+
+ multinomial = get_multinomial_op()
+ step = 0
+ s_time = time.time()
+ graph_compiled_time_buffer = []
+
+ while self._has_unfinished_sequences(this_peer_finished, synced_gpus):
+ # prepare model inputs
+
+ if input_ids.dtype == ms.int64:
+ input_ids = input_ids.to(ms.int32)
+
+ if self.config._attn_implementation == "paged_attention":
+ model_kwargs["step"] = step
+ model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+
+ # prepare variable output controls (note: some models won't accept all output controls)
+            # Note that this is slightly different from hf transformers, since paged attention does not accept
+            # indeterminate None/bool inputs
+ model_inputs.update({"output_attentions": output_attentions})
+ model_inputs.update({"output_hidden_states": output_hidden_states})
+
+ # forward pass to get next token
+ if is_prefill:
+ outputs = self(
+ **model_inputs,
+ return_dict=False if ms.get_context("mode") == ms.GRAPH_MODE else True,
+ )
+ is_prefill = False
+ else:
+ outputs = model_forward(
+ **model_inputs,
+ return_dict=False if ms.get_context("mode") == ms.GRAPH_MODE else True,
+ )
+
+ if not isinstance(outputs, ModelOutput):
+ outputs = ModelOutput(
+ loss=None,
+ logits=outputs[0],
+ past_key_values=outputs[1] if model_inputs.get("use_cache", False) else None,
+ )
+
+ if not self._supports_default_jit() or model_kwargs.get("attention_mask", None) is None:
+ next_token_logits = outputs.logits[:, -1, :]
+ else: # Get the right logits from static input shape
+ attention_mask = model_kwargs["attention_mask"]
+ cur_idx = int(attention_mask.sum(-1).max()) - 1
+
+ if outputs.logits.shape[1] == attention_mask.shape[-1]:
+ next_token_logits = outputs.logits[:, cur_idx, :] # (bs, seq, dim)
+ else:
+ next_token_logits = outputs.logits[:, -1, :]
+
+ # `input_ids` obtain effective length after 1st step
+ if input_ids.shape[1] == attention_mask.shape[1]:
+ input_ids = input_ids[:, : cur_idx + 1]
+
+ # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
+ model_kwargs = self._update_model_kwargs_for_generation(
+ outputs,
+ model_kwargs,
+ is_encoder_decoder=self.config.is_encoder_decoder,
+ )
+ if synced_gpus and this_peer_finished:
+ continue
+
+ step_time = time.time() - s_time
+ if step < 2:
+ print(f"==> sampling, step: {step}, time cost: {step_time:.5f}s")
+ else:
+ graph_compiled_time_buffer.append(step_time)
+ token_speed = len(graph_compiled_time_buffer) / sum(graph_compiled_time_buffer)
+ print(
+ f"==> sampling, step: {step}, time cost: {step_time:.5f}s, running avg speed: {token_speed:.5f}token/s"
+ )
+ s_time = time.time()
+ step += 1
+
+ # pre-process distribution
+ next_token_scores = logits_processor(input_ids, next_token_logits)
+
+ # Store scores, attentions and hidden_states when required
+ if return_dict_in_generate:
+ if output_scores:
+ scores += (next_token_scores,)
+ if output_logits:
+ raw_logits += (next_token_logits,)
+ if output_attentions:
+ decoder_attentions += (
+ (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
+ )
+ if self.config.is_encoder_decoder:
+ cross_attentions += (outputs.cross_attentions,)
+
+ if output_hidden_states:
+ decoder_hidden_states += (
+ (outputs.decoder_hidden_states,) if self.config.is_encoder_decoder else (outputs.hidden_states,)
+ )
+
+ # token selection
+ if do_sample:
+ probs = mint.softmax(next_token_scores, dim=-1, dtype=ms.float32).to(next_token_scores.dtype)
+ next_tokens = multinomial(probs, num_samples=1).squeeze(1)
+ else:
+ next_tokens = mint.argmax(next_token_scores, dim=-1)
+
+ # finished sentences should have their next token be a padding token
+ if has_eos_stopping_criteria:
+ next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
+ next_tokens = next_tokens.to(ms.int32)
+
+ # update generated ids, model inputs, and length for next step
+ input_ids = mint.cat([input_ids, next_tokens[:, None]], dim=-1)
+ if streamer is not None:
+ streamer.put(next_tokens.asnumpy())
+
+ unfinished_sequences = unfinished_sequences & ~ms.Tensor(stopping_criteria(input_ids, scores), ms.bool_)
+ this_peer_finished = unfinished_sequences.max() == 0
+ cur_len += 1
+
+ # This is needed to properly delete outputs.logits which may be very large for first iteration
+ # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
+ del outputs
+
+ if streamer is not None:
+ streamer.end()
+
+ if return_dict_in_generate:
+ if self.config.is_encoder_decoder:
+ return GenerateEncoderDecoderOutput(
+ sequences=input_ids,
+ scores=scores,
+ logits=raw_logits,
+ encoder_attentions=encoder_attentions,
+ encoder_hidden_states=encoder_hidden_states,
+ decoder_attentions=decoder_attentions,
+ cross_attentions=cross_attentions,
+ decoder_hidden_states=decoder_hidden_states,
+ past_key_values=model_kwargs.get("past_key_values"),
+ )
+ else:
+ return GenerateDecoderOnlyOutput(
+ sequences=input_ids,
+ scores=scores,
+ logits=raw_logits,
+ attentions=decoder_attentions,
+ hidden_states=decoder_hidden_states,
+ past_key_values=model_kwargs.get("past_key_values"),
+ )
+ else:
+ return input_ids
+
+ # Auxiliary functions for beam search
+ def _temporary_reorder_cache(self, past_key_values, beam_idx):
+ """
+ Temporary function to handle the different types of cache reordering processes while we roll out `Cache`.
+ TODO: standardize cache formats and make all models compatible with `Cache`. It would remove the need
+ for this function, with `Cache.reorder_cache` being the sole remaining code path
+ """
+ model_class = self.__class__.__name__.lower()
+ # Exception 1: code path for models using the legacy cache format
+ if isinstance(past_key_values, (tuple, list)):
+ past_key_values = self._reorder_cache(past_key_values, beam_idx)
+ # Exception 2: models with different cache formats. These are limited to `DynamicCache` until their
+ # cache format is standardized, to avoid adding complexity to the codebase.
+ elif "gptbigcode" in model_class:
+ if not isinstance(past_key_values, (DynamicCache, EncoderDecoderCache)):
+ raise ValueError(
+ f"Using an unsupported cache format with {model_class}. Currently, it only supports the "
+ "legacy tuple format or `DynamicCache`"
+ )
+ past_key_values = self._reorder_cache(past_key_values, beam_idx)
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+ # Standard code path: use the `Cache.reorder_cache`
+ else:
+ past_key_values.reorder_cache(beam_idx)
+ return past_key_values
+
+ @staticmethod
+ def _flatten_beam_dim(tensor: ms.Tensor) -> ms.Tensor:
+ """[batch_size, num_beams, ...] -> [batch_size * num_beams, ...]"""
+ shape = list(tensor.shape)
+ return mint.reshape(tensor, [shape[0] * shape[1]] + shape[2:])
+
+ @staticmethod
+ def _unflatten_beam_dim(tensor: ms.Tensor, batch_size: int, num_beams: int) -> ms.Tensor:
+ """[batch_size * num_beams, ...] -> [batch_size, num_beams, ...]"""
+ shape = list(tensor.shape)
+ return mint.reshape(tensor, [batch_size, num_beams] + shape[1:])
+
+ @staticmethod
+ def _gather_beams(tensor: ms.Tensor, beam_indices: ms.Tensor) -> ms.Tensor:
+ """
+ Gathers the beam slices indexed by beam_indices into new beam array.
+ Args:
+ tensor (`ms.Tensor`): A tensor containing data to be gathered. The tensor is a 2D or a 3D tensor
+ with the two first dimensions depicting the batch and the beam dimensions.
+ beam_indices (`ms.Tensor` of shape `(batch_size, num_beams_to_select)`): The indices of the beams to
+                select.
+ Returns:
A tensor with the selected beams
"""
- # `take_along_dim` requires its indices arg to have the same number of dims as `input`
- while len(beam_indices.shape) < len(tensor.shape):
- beam_indices = beam_indices.unsqueeze(-1)
- gathered_tensor = ms.Tensor(ms.numpy.take_along_axis(arr=tensor, indices=beam_indices, axis=1))
- return gathered_tensor
+ # `take_along_dim` requires its indices arg to have the same number of dims as `input`
+ while len(beam_indices.shape) < len(tensor.shape):
+ beam_indices = beam_indices.unsqueeze(-1)
+ gathered_tensor = ms.Tensor(ms.numpy.take_along_axis(arr=tensor, indices=beam_indices, axis=1))
+ return gathered_tensor
+
+ @staticmethod
+ def _check_early_stop_heuristic(
+ is_early_stop_heuristic_unsatisfied: ms.Tensor,
+ running_beam_scores: ms.Tensor,
+ beam_scores: ms.Tensor,
+ is_sent_finished: ms.Tensor,
+ cur_len: int,
+ max_length: int,
+ decoder_prompt_len: int,
+ early_stopping: Union[bool, str],
+ length_penalty: float,
+ ):
+ """
+ Determine whether early stopping is possible by checking if the best possible score of running beams
+ could still improve upon the finished ones.
+
+ Mechanism:
+ - Without a length penalty, beam scores typically decrease as more tokens are generated.
+ So, if the *best possible* score from any running beam is already worse than the *worst* finished beam,
+ we can safely stop early.
+ - With a length penalty, scores may increase with longer sequences. In this case, we use heuristics
+ to estimate the best possible score — though this estimate may not always be correct — and stop
+ if no further improvement seems likely.
+
+ We apply different heuristics depending on the value of `early_stopping`:
+ 1. `early_stopping == False`:
+ -> Use a heuristic that assumes the best score comes from the current length minus the decoder prompt length.
+ -> See detailed discussion: https://github.com/huggingface/transformers/pull/20901#issuecomment-1369845565
+
+ 2. `early_stopping == "never"`:
+ -> Estimate the best score using either `max_length` or `cur_len`, depending on the sign of `length_penalty`.
+ -> A positive length penalty favors longer sequences, so we use `max_length` in that case.
+
+ NOTE: the canonical beam search implementation can be replicated with `early_stopping="never"` and
+ `length_penalty=0.0`, which are NOT the default flags. The default behavior was empirically found to produce
+ better sequences (prior to 2022), and changing it is BC breaking.
+ """
+ if early_stopping == "never" and length_penalty > 0.0:
+ best_hypothetical_length = max_length - decoder_prompt_len
+ else:
+ best_hypothetical_length = cur_len - decoder_prompt_len
+ best_possible_running_score = running_beam_scores[:, :1] / (best_hypothetical_length**length_penalty)
+ worst_finished_score = mint.where(is_sent_finished, mint.min(beam_scores, dim=1, keepdim=True)[0], -1.0e9)
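+        # For batches with no finished hypothesis yet, `worst_finished_score` falls back to the -1e9 placeholder,
+        # so the comparison below stays True and the heuristic keeps those batches running.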
+ return is_early_stop_heuristic_unsatisfied & mint.any(
+ best_possible_running_score > worst_finished_score, dim=-1, keepdim=True
+ )
+
+ @staticmethod
+ def _beam_search_has_unfinished_sequences(
+ running_beam_scores: ms.Tensor,
+ beam_scores: ms.Tensor,
+ is_sent_finished: ms.Tensor,
+ next_token_hits_stopping_criteria: ms.Tensor,
+ cur_len: int,
+ max_length: int,
+ decoder_prompt_len: int,
+ early_stopping: Union[bool, str],
+ length_penalty: float,
+ ):
+ """
+ Beam Search stopping condition -- halts the generation loop if any of these conditions becomes False
+ """
+ # a. Can the open beams improve the top completed scores?
+ # early_stopping == False -> apply heuristic = always get the best score from
+ # `cur_len - decoder_prompt_len`. See the discussion below for more details.
+ # https://github.com/huggingface/transformers/pull/20901#issuecomment-1369845565
+ # early_stopping == "never" -> compute the best score from `max_length` or `cur_len`, depending on the
+ # sign of `length_penalty`. Positive `length_penalty` favors longer sequences, thus we use
+ # `max_length` there.
+ if early_stopping == "never" and length_penalty > 0.0:
+ best_hypothetical_length = max_length - decoder_prompt_len
+ else:
+ best_hypothetical_length = cur_len - decoder_prompt_len
+ best_possible_running_score = running_beam_scores[:, :1] / (best_hypothetical_length**length_penalty)
+ worst_finished_score = mint.where(is_sent_finished, mint.min(beam_scores, dim=1, keepdim=True)[0], -1.0e9)
+ improvement_possible = mint.any(best_possible_running_score > worst_finished_score)
+
+ # b. Is there still a beam without fully completed sequences? This is only relevant if early_stopping is
+ # enabled, where we want to finish as soon as all beams have a completed sequence.
+ exists_open_beam = ms.Tensor(
+ ~(mint.all(is_sent_finished) & ms.Tensor(early_stopping is True, ms.int32)), ms.int32
+ )
+
+        # c. Have we hit a stopping criterion with all running sequences and have no way to continue? e.g. we have
+        # reached `max_length`
+ valid_continuations = ~mint.all(next_token_hits_stopping_criteria)
+
+ return improvement_possible & exists_open_beam & valid_continuations
+
+ def _get_top_k_continuations(
+ self,
+ accumulated_log_probs: ms.Tensor,
+ running_sequences: ms.Tensor,
+ running_beam_indices: ms.Tensor,
+ cur_len: int,
+ decoder_prompt_len: int,
+ do_sample: bool,
+ beams_to_keep: int,
+ num_beams: int,
+ vocab_size: int,
+ batch_size: int,
+ ) -> tuple[ms.Tensor, ms.Tensor, ms.Tensor]:
+ """
+ Get top-K continuations given the accumulated log probs on the next token.
+ A few notes to understand what's going on:
+ 1. Each item in batch has `num_beams` * `vocab_size` candidate continuations. For each item, get the
+ top K [K = (number of EOS tokens + 1) * `num_beams`] candidates with the highest accumulated
+ log-probabilities, or sample them without replacement using the accumulated scores
+ 2. We gather the top K (as opposed to `num_beams`, or any number lower than K) here so that we have at
+ least `num_beams` sequences remaining to continue the live beam search.
+ 3. Note that other stopping criteria might result in impossible to continue beams, i.e. all continuations
+ selected in this step hit the stopping criteria.
+ """
+ # TODO (joao): This function should take an optional beam scorer function, to manipulate the scores after
+ # token selection. The function should be an argument exposed, so that custom scoring functions can be
+ # defined.
+
+ # Gather the top K scores from _all_ beams.
+ if do_sample:
+ topk_indices = mint.multinomial(
+ mint.nn.functional.softmax(accumulated_log_probs, dim=-1), num_samples=beams_to_keep
+ )
+ topk_log_probs = mint.gather(input=accumulated_log_probs, dim=1, index=topk_indices)
+ else:
+ topk_log_probs, topk_indices = mint.topk(accumulated_log_probs, k=beams_to_keep)
+
+ # Gather K top beams, recover the beam index by floor division and token id by modulo division
+ topk_current_beam_indices = topk_indices // vocab_size
+ topk_running_beam_indices = self._gather_beams(running_beam_indices, topk_current_beam_indices)
+ topk_running_sequences = self._gather_beams(running_sequences, topk_current_beam_indices)
+ topk_ids = topk_indices % vocab_size
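+        # e.g. with vocab_size=32000, the flat index 64005 decodes to beam 64005 // 32000 = 2 and token id 64005 % 32000 = 5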
+
+ # Update sequences for the K top-k new sequences.
+ topk_running_sequences[:, :, cur_len] = topk_ids
+
+ # we want to store the beam indices with batch information -> real beam index = beam index % num beams
+ batch_offset = mint.arange(batch_size).view(-1, 1) * num_beams
+ batch_modified_indices = topk_current_beam_indices + batch_offset
+ topk_running_beam_indices[:, :, cur_len - decoder_prompt_len] = batch_modified_indices
+
+ return topk_log_probs, topk_running_sequences, topk_running_beam_indices
+
+ def _get_running_beams_for_next_iteration(
+ self,
+ topk_log_probs: ms.Tensor,
+ topk_running_sequences: ms.Tensor,
+ topk_running_beam_indices: ms.Tensor,
+ next_token_hits_stopping_criteria: ms.Tensor,
+ num_beams: int,
+ ) -> tuple[ms.Tensor, ms.Tensor, ms.Tensor]:
+ """
+ Given the top-K continuations, their scores, and whether they hit a stopping criteria, select the
+ best non-finished beams to continue beam search in the next iteration.
+ """
+ # To prevent these just finished sequences from being used in subsequent iterations, set their log probs
+ # to a very large negative value
+ topk_running_log_probs = topk_log_probs + next_token_hits_stopping_criteria.to(ms.float32) * -1.0e9
+
+ next_topk_indices = mint.topk(topk_running_log_probs, k=num_beams)[1]
+ running_sequences = self._gather_beams(topk_running_sequences, next_topk_indices)
+ running_beam_scores = self._gather_beams(topk_running_log_probs, next_topk_indices)
+ running_beam_indices = self._gather_beams(topk_running_beam_indices, next_topk_indices)
+ return running_sequences, running_beam_scores, running_beam_indices
+
+ def _update_finished_beams(
+ self,
+ sequences: ms.Tensor,
+ topk_running_sequences: ms.Tensor,
+ beam_scores: ms.Tensor,
+ topk_log_probs: ms.Tensor,
+ beam_indices: ms.Tensor,
+ topk_running_beam_indices: ms.Tensor,
+ is_early_stop_heuristic_unsatisfied: ms.Tensor,
+ is_sent_finished: ms.Tensor,
+ next_token_hits_stopping_criteria: ms.Tensor,
+ top_num_beam_mask: ms.Tensor,
+ num_beams: int,
+ cur_len: int,
+ decoder_prompt_len: int,
+ length_penalty: float,
+ early_stopping: Union[bool, str],
+ ) -> tuple[ms.Tensor, ms.Tensor, ms.Tensor, ms.Tensor]:
+ """
+ Updates the finished beams if (and only if) there are new completed sequences that have a higher score than
+ the current finished sequences.
+ """
+ # Only the top `num_beam` sequences can be considered for the final returned sequences. Remember: the
+ # remaining sequences only exist as a backup to ensure that we have at least `num_beams` sequences to
+ # continue.
+ did_top_num_beams_just_finished = next_token_hits_stopping_criteria & top_num_beam_mask[None, :].to(ms.int32)
+
+ # Further process topk logits for the finished beams
+ # - add length penalty
+ topk_log_probs = topk_log_probs / ((cur_len + 1 - decoder_prompt_len) ** length_penalty)
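+        # scores are summed log-probabilities divided by length**length_penalty, so a positive penalty
+        # favors longer finished hypotheses (matching the `early_stopping` notes above)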
+ # - make sure no scores can be added anymore if beam is full and early stopping is on
+ beams_in_batch_are_full = mint.all(is_sent_finished, dim=-1, keepdim=True) & ms.Tensor(
+ early_stopping is True, ms.int32
+ )
+ topk_log_probs += beams_in_batch_are_full.to(ms.float32) * -1.0e9
+ # - make sure no scores can be added anymore if improvement is not possible
+ topk_log_probs += (~is_early_stop_heuristic_unsatisfied).to(ms.float32) * -1.0e9
+
+ # - make sure still running sequences cannot be chosen as finalized beam
+ topk_log_probs += (~did_top_num_beams_just_finished) * -1.0e9
+
+ # Get finalized `num_beam` sequences for the next generation step -- combine the previous finalized
+ # data with the new finalized sequences (if any, non-finalized sequences have a very large negative score
+ # in this step), and keep the best `num_beams` sequences.
+ merged_sequences = mint.cat((sequences, topk_running_sequences), dim=1)
+ merged_scores = mint.cat((beam_scores, topk_log_probs), dim=1)
+ merged_beam_indices = mint.cat((beam_indices, topk_running_beam_indices), dim=1)
+ merged_is_sent_finished = mint.cat((is_sent_finished, did_top_num_beams_just_finished.to(ms.bool_)), dim=1)
+ topk_merged_indices = mint.topk(merged_scores, k=num_beams)[1]
+ sequences = self._gather_beams(merged_sequences, topk_merged_indices)
+ beam_scores = self._gather_beams(merged_scores, topk_merged_indices)
+ beam_indices = self._gather_beams(merged_beam_indices, topk_merged_indices)
+ is_sent_finished = self._gather_beams(merged_is_sent_finished, topk_merged_indices)
+ return sequences, beam_scores, beam_indices, is_sent_finished
+
+ # end of auxiliary functions for beam search
+
+ def _beam_search(
+ self,
+ input_ids: ms.Tensor,
+ logits_processor: LogitsProcessorList,
+ stopping_criteria: StoppingCriteriaList,
+ generation_config: GenerationConfig,
+ synced_gpus: bool,
+ **model_kwargs,
+ ) -> Union[GenerateBeamOutput, ms.Tensor]:
+ r"""
+ Generates sequences of token ids for models with a language modeling head using **beam search decoding** and
+ can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
+ If it's the first time you're diving into Beam Search, we recommend you read the following blog post:
+ https://huggingface.co/blog/how-to-generate (especially the beam search section).
+ You can recompute the sequence scores from the individual scores using the `compute_transition_scores` function
+ (https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationMixin.compute_transition_scores)
+ Parameters:
+ input_ids (`ms.Tensor` of shape `(batch_size*num_beams, sequence_length)`):
+ The sequence used as a prompt for the generation.
+ logits_processor (`LogitsProcessorList`):
+ An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
+ used to modify the prediction scores of the language modeling head applied at each generation step.
+ stopping_criteria (`StoppingCriteriaList`):
+ An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
+ used to tell if the generation loop should stop.
+ generation_config ([`~generation.GenerationConfig`]):
+ The generation configuration to be used as parametrization of the decoding method.
+ synced_gpus (`bool`):
+ Whether to continue running the while loop until max_length (needed to avoid deadlocking with
+ `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
+ model_kwargs:
+ Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
+ an encoder-decoder model the kwargs should include `encoder_outputs`.
+ Return:
+ [`~generation.GenerateBeamDecoderOnlyOutput`], [`~generation.GenerateBeamEncoderDecoderOutput`] or
+ `ms.Tensor`: A `ms.Tensor` containing the generated tokens (default behaviour) or a
+ [`~generation.GenerateBeamDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
+ `return_dict_in_generate=True` or a [`~generation.GenerateBeamEncoderDecoderOutput`] if
+ `model.config.is_encoder_decoder=True`.
+ """
+
+ # 1. init beam_search values
+ pad_token_id = generation_config._pad_token_tensor
+ eos_token_id = generation_config._eos_token_tensor
+ output_attentions = generation_config.output_attentions
+ output_hidden_states = generation_config.output_hidden_states
+ output_scores = generation_config.output_scores
+ output_logits = generation_config.output_logits
+ return_dict_in_generate = generation_config.return_dict_in_generate
+ do_sample = generation_config.do_sample
+ early_stopping = generation_config.early_stopping
+ length_penalty = generation_config.length_penalty
+ max_length = generation_config.max_length
+ num_beams = generation_config.num_beams
+ num_return_sequences = generation_config.num_return_sequences
+
+ batch_size_unflattened, cur_len = input_ids.shape[:2]
+ batch_size = batch_size_unflattened // num_beams
+ # TODO (joao): standardize special cases
+ if self.__class__.__name__ == "MoshiDepthDecoder":
+ vocab_size = self.config.audio_vocab_size
+ elif self.__class__.__name__ == "ImageGPTForCausalImageModeling":
+ vocab_size = self.get_output_embeddings().out_features
+ else:
+ vocab_size = self.config.get_text_config().vocab_size
+ decoder_prompt_len = cur_len
+ this_peer_finished = False
+
+ # At each beam search step, we want to keep top K [K = (number of EOS tokens + 1) * `num_beams`] candidates
+ # with the highest log-probabilities, or sample K continuations without replacement. We gather the top K
+ # (as opposed to `num_beams`, or any number lower than K) so that we have at least `num_beams` sequences
+ # non-finished to continue the live beam search, in case the top `num_beams` all select an EOS token.
+ n_eos_tokens = eos_token_id.shape[0] if eos_token_id is not None else 0
+ beams_to_keep = max(2, 1 + n_eos_tokens) * num_beams
+ top_num_beam_mask = mint.cat(
+ (mint.ones((num_beams), dtype=ms.bool_), mint.zeros((beams_to_keep - num_beams), dtype=ms.bool_)),
+ dim=0,
+ )
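+ # e.g. (illustrative): with num_beams=4 and a single EOS token, beams_to_keep = max(2, 1 + 1) * 4 = 8,
+ # and `top_num_beam_mask` marks only the first 4 of those 8 slots as eligible for finalization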
+
+ model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
+
+ # (joao) feature lost in the refactor. Probably won't implement, hurts readability with minimal gains (there
+ # are newer low-memory alternatives like the offloaded cache)
+ sequential = generation_config.low_memory
+ if sequential:
+ raise ValueError(
+ "`low_memory=True` is not supported after the beam search refactor. Please check the discussion in "
+ "#35802 *after the PR got merged*, and add a comment there if your questions are not yet answered."
+ )
+
+ # 2. init output tuples
+ all_scores = () if (return_dict_in_generate and output_scores) else None
+ raw_logits = () if (return_dict_in_generate and output_logits) else None
+ beam_indices = () if (return_dict_in_generate and output_logits) else None
+ decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
+ cross_attentions = () if (return_dict_in_generate and output_attentions) else None
+ decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
+
+ # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
+ if return_dict_in_generate and self.config.is_encoder_decoder:
+ encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
+ encoder_hidden_states = (
+ model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
+ )
+
+ # 3. init running tensors and static-shaped placeholders
+
+ # per batch, beam-item holding current token in loop and completed sequences
+ output_fill_value = pad_token_id or eos_token_id[0] if eos_token_id is not None else -1
+ running_sequences = mint.full(
+ (batch_size, num_beams, max_length),
+ fill_value=output_fill_value,
+ dtype=ms.int64,
+ )
+ running_sequences[:, :, :cur_len] = self._unflatten_beam_dim(input_ids, batch_size, num_beams)
+ sequences = running_sequences.copy() # .detach()
+
+ # per batch, beam-item score, logprobs
+ # initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens
+ # of the first beam are considered to avoid sampling the exact same tokens across all beams.
+ running_beam_scores = mint.zeros((batch_size, num_beams), dtype=ms.float32)
+ running_beam_scores[:, 1:] = -1e9
+ beam_scores = mint.full((batch_size, num_beams), fill_value=-1e9, dtype=ms.float32)
+
+ # per batch, beam-item state bit indicating if sentence has finished.
+ is_sent_finished = mint.zeros((batch_size, num_beams), dtype=ms.bool_)
+
+ # per batch state bit indicating if there is a possibility to improve the best finished sentence.
+ is_early_stop_heuristic_unsatisfied = mint.ones((batch_size, 1), dtype=ms.bool_)
+
+ # per batch, beam-item state bit indicating if there are valid continuations.
+ next_token_hits_stopping_criteria = mint.zeros((batch_size, num_beams), dtype=ms.bool_)
+
+ # per batch selected beam indices
+ running_beam_indices = mint.full((batch_size, num_beams, max_length - cur_len), fill_value=-1, dtype=ms.int32)
+ beam_indices = running_beam_indices.copy() # .detach()
+
+ # 4. run the generation loop
+ while self._has_unfinished_sequences(this_peer_finished, synced_gpus):
+ # a. Forward current tokens, obtain the logits
+ flat_running_sequences = self._flatten_beam_dim(running_sequences[:, :, :cur_len])
+ model_inputs = self.prepare_inputs_for_generation(flat_running_sequences, **model_kwargs)
+
+ # prepare variable output controls (note: some models won't accept all output controls)
+ model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
+ model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
+
+ model_outputs = self(**model_inputs, return_dict=True)
+
+ # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
+ model_kwargs = self._update_model_kwargs_for_generation(
+ model_outputs,
+ model_kwargs,
+ is_encoder_decoder=self.config.is_encoder_decoder,
+ )
+ if synced_gpus and this_peer_finished:
+ continue
+
+ logits = model_outputs.logits[:, -1, :].copy().float() # copy is needed to avoid keeping a hanging ref
+
+ # b. Compute log probs -- get log probabilities from logits, process logits with processors (*e.g.*
+ # `temperature`, ...), and add new logprobs to existing running logprobs scores.
+ log_probs = mint.nn.functional.log_softmax(logits, dim=-1)
+ log_probs = logits_processor(flat_running_sequences, log_probs)
+
+ # Store logits, attentions and hidden_states when required
+ if return_dict_in_generate:
+ if output_logits:
+ raw_logits += (logits.copy(),)
+ if output_scores:
+ all_scores += (log_probs.copy(),)
+
+ if output_attentions:
+ decoder_attentions += (
+ (model_outputs.decoder_attentions,)
+ if self.config.is_encoder_decoder
+ else (model_outputs.attentions,)
+ )
+ if self.config.is_encoder_decoder:
+ cross_attentions += (model_outputs.cross_attentions,)
+
+ if output_hidden_states:
+ decoder_hidden_states += (
+ (model_outputs.decoder_hidden_states,)
+ if self.config.is_encoder_decoder
+ else (model_outputs.hidden_states,)
+ )
+
+ # This is needed to properly delete logits which may be very large for first iteration
+ # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
+ del model_outputs
+
+ log_probs = self._unflatten_beam_dim(log_probs, batch_size, num_beams)
+ log_probs = log_probs + running_beam_scores[:, :, None]
+ log_probs = mint.reshape(log_probs, (batch_size, num_beams * vocab_size))
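+ # e.g. (illustrative): with num_beams=2 and vocab_size=5, flattened index 7 in this (batch, num_beams * vocab)
+ # view corresponds to beam 7 // 5 = 1 and token id 7 % 5 = 2, which is how the top-k continuations are
+ # recovered inside `_get_top_k_continuations` (called just below)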
+
+ # c. Retrieve top-K continuations, i.e. select the next token (greedy or sampling) and then keep the best
+ # continuations among all beams based on the accumulated scores.
+ topk_log_probs, topk_running_sequences, topk_running_beam_indices = self._get_top_k_continuations(
+ accumulated_log_probs=log_probs,
+ running_sequences=running_sequences,
+ running_beam_indices=running_beam_indices,
+ cur_len=cur_len,
+ decoder_prompt_len=decoder_prompt_len,
+ do_sample=do_sample,
+ beams_to_keep=beams_to_keep,
+ num_beams=num_beams,
+ vocab_size=vocab_size,
+ batch_size=batch_size,
+ )
+
+ # d. Check which running sequences have finished
+ next_token_hits_stopping_criteria = stopping_criteria(
+ self._flatten_beam_dim(topk_running_sequences[:, :, : cur_len + 1]), # remove unfilled token indexes
+ all_scores,
+ )
+ next_token_hits_stopping_criteria = self._unflatten_beam_dim(
+ next_token_hits_stopping_criteria, batch_size, beams_to_keep
+ )
+
+ # e. Get the non-finished running `num_beams` sequences for the next generation step
+ running_sequences, running_beam_scores, running_beam_indices = self._get_running_beams_for_next_iteration(
+ topk_log_probs=topk_log_probs,
+ topk_running_sequences=topk_running_sequences,
+ topk_running_beam_indices=topk_running_beam_indices,
+ next_token_hits_stopping_criteria=next_token_hits_stopping_criteria,
+ num_beams=num_beams,
+ )
+
+ # f. Update the completed beams if a new high score in a finished sequence is found
+ sequences, beam_scores, beam_indices, is_sent_finished = self._update_finished_beams(
+ sequences=sequences,
+ topk_running_sequences=topk_running_sequences,
+ beam_scores=beam_scores,
+ topk_log_probs=topk_log_probs,
+ beam_indices=beam_indices,
+ topk_running_beam_indices=topk_running_beam_indices,
+ is_early_stop_heuristic_unsatisfied=is_early_stop_heuristic_unsatisfied,
+ is_sent_finished=is_sent_finished,
+ next_token_hits_stopping_criteria=next_token_hits_stopping_criteria,
+ top_num_beam_mask=top_num_beam_mask,
+ num_beams=num_beams,
+ cur_len=cur_len,
+ decoder_prompt_len=decoder_prompt_len,
+ length_penalty=length_penalty,
+ early_stopping=early_stopping,
+ )
+
+ # g. Prepare remaining data for the next iteration, including computing the stopping condition for
+ # beam search as a whole (as opposed to individual beams, i.e. `stopping_criteria`)
+
+ # pluck the cache from the beam indices that will be used in the next iteration
+ # NOTE: we need to check if `self._reorder_cache` exists for special models like RAG, RecurrentGemma etc.
+ if model_kwargs.get("past_key_values", None) is not None:
+ beam_idx = self._flatten_beam_dim(running_beam_indices[..., cur_len - decoder_prompt_len])
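+ # e.g. (illustrative): with batch_size=2 and num_beams=3, a stored index of 4 refers to beam 4 % 3 = 1
+ # of batch item 4 // 3 = 1, so the corresponding cache rows are gathered for the next step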
+ if hasattr(self, "_reorder_cache"):
+ model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx)
+ else:
+ model_kwargs["past_key_values"].reorder_cache(beam_idx)
+
+ cur_len = cur_len + 1
+ is_early_stop_heuristic_unsatisfied = self._check_early_stop_heuristic(
+ is_early_stop_heuristic_unsatisfied=is_early_stop_heuristic_unsatisfied,
+ running_beam_scores=running_beam_scores,
+ beam_scores=beam_scores,
+ is_sent_finished=is_sent_finished,
+ cur_len=cur_len,
+ max_length=max_length,
+ decoder_prompt_len=decoder_prompt_len,
+ early_stopping=early_stopping,
+ length_penalty=length_penalty,
+ )
+ this_peer_finished = not self._beam_search_has_unfinished_sequences(
+ is_early_stop_heuristic_unsatisfied,
+ is_sent_finished,
+ next_token_hits_stopping_criteria,
+ early_stopping,
+ )
- @staticmethod
- def _beam_search_has_unfinished_sequences(
- running_beam_scores: ms.Tensor,
- beam_scores: ms.Tensor,
- is_sent_finished: ms.Tensor,
- next_token_hits_stopping_criteria: ms.Tensor,
- cur_len: int,
- max_length: int,
- decoder_prompt_len: int,
- early_stopping: Union[bool, str],
- length_penalty: float,
+ # 5. prepare outputs
+ # Take best beams for each batch (the score is sorted in descending order)
+ sequences = self._flatten_beam_dim(sequences[:, :num_return_sequences, :])
+ beam_scores = self._flatten_beam_dim(beam_scores[:, :num_return_sequences])
+ beam_indices = self._flatten_beam_dim(beam_indices[:, :num_return_sequences, :])
+
+ # Crop the static-shaped tensors to the actual size.
+ # `beam_indices` is initialized with -1s, and is updated with the beam index of the generated token at each
+ # step. We can use it to detect the generated length, which may be != `cur_len` (e.g. selected beam is from a
+ # previous decoding iteration)
+ max_generated_length = ((beam_indices + 1).bool()).sum(dim=1).max()
+ output_length = decoder_prompt_len + max_generated_length
+ sequences = sequences[:, :output_length]
+ beam_indices = beam_indices[:, :max_generated_length]
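+ # e.g. (illustrative): if the longest returned hypothesis has 7 generated tokens, its row of `beam_indices`
+ # holds 7 non-negative entries, so the padded outputs are cropped to decoder_prompt_len + 7 even if
+ # `cur_len` advanced further before the loop stopped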
+
+ if return_dict_in_generate:
+ if not output_scores:
+ beam_scores = None
+
+ if self.config.is_encoder_decoder:
+ return GenerateBeamEncoderDecoderOutput(
+ sequences=sequences,
+ sequences_scores=beam_scores,
+ scores=all_scores,
+ logits=raw_logits,
+ beam_indices=beam_indices,
+ encoder_attentions=encoder_attentions,
+ encoder_hidden_states=encoder_hidden_states,
+ decoder_attentions=decoder_attentions,
+ cross_attentions=cross_attentions,
+ decoder_hidden_states=decoder_hidden_states,
+ past_key_values=model_kwargs.get("past_key_values"),
+ )
+ else:
+ return GenerateBeamDecoderOnlyOutput(
+ sequences=sequences,
+ sequences_scores=beam_scores,
+ scores=all_scores,
+ logits=raw_logits,
+ beam_indices=beam_indices,
+ attentions=decoder_attentions,
+ hidden_states=decoder_hidden_states,
+ past_key_values=model_kwargs.get("past_key_values"),
+ )
+ else:
+ return sequences
+
+ def _group_beam_search(
+ self,
+ input_ids: ms.Tensor,
+ beam_scorer: BeamScorer,
+ logits_processor: LogitsProcessorList,
+ stopping_criteria: StoppingCriteriaList,
+ generation_config: GenerationConfig,
+ synced_gpus: bool,
+ **model_kwargs,
):
+ r"""
+ Generates sequences of token ids for models with a language modeling head using **diverse beam search
+ decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
+
+ Parameters:
+ input_ids (`ms.Tensor` of shape `(batch_size*num_beams, sequence_length)`):
+ The sequence used as a prompt for the generation.
+ beam_scorer (`BeamScorer`):
+ A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and
+ sorted during generation. For more information, the documentation of [`BeamScorer`] should be read.
+ logits_processor (`LogitsProcessorList`):
+ An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
+ used to modify the prediction scores of the language modeling head applied at each generation step.
+ stopping_criteria (`StoppingCriteriaList`):
+ An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
+ used to tell if the generation loop should stop.
+ generation_config ([`~generation.GenerationConfig`]):
+ The generation configuration to be used as parametrization of the decoding method.
+ synced_gpus (`bool`):
+ Whether to continue running the while loop until max_length (needed to avoid deadlocking with
+ `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
+ model_kwargs:
+ Additional model specific kwargs that will be forwarded to the `forward` function of the model. If
+ model is an encoder-decoder model the kwargs should include `encoder_outputs`.
+
+ Return:
+ [`~generation.GenerateBeamDecoderOnlyOutput`], [`~generation.GenerateBeamEncoderDecoderOutput`] or
+ `ms.Tensor`: A `ms.Tensor` containing the generated tokens (default behaviour) or a
+ [`~generation.GenerateBeamDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
+ `return_dict_in_generate=True` or a [`~generation.GenerateBeamEncoderDecoderOutput`] if
+ `model.config.is_encoder_decoder=True`.
"""
- Beam Search stopping condition -- halts the generation loop if any of these conditions becomes False
- """
- # a. Can the open beams improve the top completed scores?
- # early_stopping == False -> apply heuristic = always get the best score from
- # `cur_len - decoder_prompt_len`. See the discussion below for more details.
- # https://github.com/huggingface/transformers/pull/20901#issuecomment-1369845565
- # early_stopping == "never" -> compute the best score from `max_length` or `cur_len`, depending on the
- # sign of `length_penalty`. Positive `length_penalty` favors longer sequences, thus we use
- # `max_length` there.
- if early_stopping == "never" and length_penalty > 0.0:
- best_hypothetical_length = max_length - decoder_prompt_len
+ # init values
+ pad_token_id = generation_config._pad_token_tensor
+ eos_token_id = generation_config._eos_token_tensor
+ output_attentions = generation_config.output_attentions
+ output_hidden_states = generation_config.output_hidden_states
+ output_scores = generation_config.output_scores
+ output_logits = generation_config.output_logits
+ return_dict_in_generate = generation_config.return_dict_in_generate
+
+ num_beams = beam_scorer.num_beams
+ num_beam_groups = beam_scorer.num_beam_groups
+ num_sub_beams = num_beams // num_beam_groups
+ batch_size = len(beam_scorer._beam_hyps) // num_beam_groups
+
+ batch_beam_size, cur_len = input_ids.shape
+ model_kwargs = self._get_initial_cache_position(cur_len, model_kwargs)
+
+ if return_dict_in_generate and output_scores:
+ beam_indices = [tuple(() for _ in range(num_sub_beams * batch_size)) for _ in range(num_beam_groups)]
else:
- best_hypothetical_length = cur_len - decoder_prompt_len
- best_possible_running_score = running_beam_scores[:, :1] / (best_hypothetical_length**length_penalty)
- worst_finished_score = mint.where(is_sent_finished, mint.min(beam_scores, dim=1, keepdim=True)[0], -1.0e9)
- improvement_possible = mint.any(best_possible_running_score > worst_finished_score)
+ beam_indices = None
- # b. Is there still a beam without fully completed sequences? This is only relevant if early_stopping is
- # enabled, where we want to finish as soon as all beams have a completed sequence.
- exists_open_beam = ms.Tensor(
- ~(mint.all(is_sent_finished) & ms.Tensor(early_stopping is True, ms.int32)), ms.int32
- )
+ if num_beams * batch_size != batch_beam_size:
+ raise ValueError(
+ f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
+ )
- # c. Have we hit a stopping criteria with all running sequences and have no way to continue? e.g. we have
- # reached `max_length``
- valid_continuations = ~mint.all(next_token_hits_stopping_criteria)
+ # init attention / hidden states / scores tuples
+ scores = () if (return_dict_in_generate and output_scores) else None
+ raw_logits = () if (return_dict_in_generate and output_logits) else None
+ decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
+ cross_attentions = () if (return_dict_in_generate and output_attentions) else None
+ decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
- return improvement_possible & exists_open_beam & valid_continuations
+ # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
+ if return_dict_in_generate and self.config.is_encoder_decoder:
+ encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
+ encoder_hidden_states = (
+ model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
+ )
- def _get_top_k_continuations(
- self,
- accumulated_log_probs: ms.Tensor,
- running_sequences: ms.Tensor,
- running_beam_indices: ms.Tensor,
- cur_len: int,
- decoder_prompt_len: int,
- do_sample: bool,
- beams_to_keep: int,
- num_beams: int,
- vocab_size: int,
- batch_size: int,
- ) -> Tuple[ms.Tensor, ms.Tensor, ms.Tensor]:
- """
- Get top-K continuations given the accumulated log probs on the next token.
- A few notes to understand what's going on:
- 1. Each item in batch has `num_beams` * `vocab_size` candidate continuations. For each item, get the
- top K [K = (number of EOS tokens + 1) * `num_beams`] candidates with the highest accumulated
- log-probabilities, or sample them without replacement using the accumulated scores
- 2. We gather the top K (as opposed to `num_beams`, or any number lower than K) here so that we have at
- least `num_beams` sequences remaining to continue the live beam search.
- 3. Note that other stopping criteria might result in impossible to continue beams, i.e. all continuations
- selected in this step hit the stopping criteria.
- """
- # TODO (joao): This function should take an optional beam scorer function, to manipulate the scores after
- # token selection. The function should be an argument exposed, so that custom scoring functions can be
- # defined.
+ # initialise score of first beam of each group with 0 and the rest with -1e9. This ensures that the beams in
+ # the same group don't produce same tokens every time.
+ beam_scores = mint.full((batch_size, num_beams), fill_value=-1e9, dtype=ms.float32)
+ beam_scores[:, ::num_sub_beams] = 0
+ beam_scores = beam_scores.view((batch_size * num_beams,))
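+ # e.g. (illustrative): with num_beams=6 and num_beam_groups=3, num_sub_beams=2 and positions 0, 2 and 4
+ # of each row start at 0 while the remaining beams start at -1e9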
- # Gather the top K scores from _all_ beams.
- if do_sample:
- topk_indices = mint.multinomial(
- mint.nn.functional.softmax(accumulated_log_probs, dim=-1), num_samples=beams_to_keep
+ this_peer_finished = False
+
+ decoder_prompt_len = input_ids.shape[1] # record the prompt length of decoder
+ while self._has_unfinished_sequences(this_peer_finished, synced_gpus):
+ # predicted tokens in cur_len step
+ current_tokens = mint.zeros(batch_size * num_beams, dtype=input_ids.dtype)
+
+ # indices which will form the beams in the next time step
+ reordering_indices = mint.zeros(batch_size * num_beams, dtype=ms.int64)
+
+ # do one decoder step on all beams of all sentences in batch
+ model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+
+ # prepare variable output controls (note: some models won't accept all output controls)
+ model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
+ model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
+
+ outputs = self(**model_inputs, return_dict=True)
+
+ # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
+ model_kwargs = self._update_model_kwargs_for_generation(
+ outputs,
+ model_kwargs,
+ is_encoder_decoder=self.config.is_encoder_decoder,
)
- topk_log_probs = mint.gather(input=accumulated_log_probs, dim=1, index=topk_indices)
- else:
- topk_log_probs, topk_indices = mint.topk(accumulated_log_probs, k=beams_to_keep)
+ if synced_gpus and this_peer_finished:
+ cur_len = cur_len + 1
+ continue
- # Gather K top beams, recover the beam index by floor division and token id by modulo division
- topk_current_beam_indices = topk_indices // vocab_size
- topk_running_beam_indices = self._gather_beams(running_beam_indices, topk_current_beam_indices)
- topk_running_sequences = self._gather_beams(running_sequences, topk_current_beam_indices)
- topk_ids = topk_indices % vocab_size
+ if output_scores:
+ processed_score = mint.zeros_like(outputs.logits[:, -1, :])
+ if output_logits:
+ # Copy is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
+ # (the copy itself is always small)
+ raw_logit_score = outputs.logits[:, -1, :].copy()
- # Update sequences for the K top-k new sequences.
- topk_running_sequences[:, :, cur_len] = topk_ids
+ for beam_group_idx in range(num_beam_groups):
+ group_start_idx = beam_group_idx * num_sub_beams
+ group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
+ group_size = group_end_idx - group_start_idx
- # we want to store the beam indices with batch information -> real beam index = beam index % num beams
- batch_offset = mint.arange(batch_size).view(-1, 1) * num_beams
- batch_modified_indices = topk_current_beam_indices + batch_offset
- topk_running_beam_indices[:, :, cur_len - decoder_prompt_len] = batch_modified_indices
+ # indices of beams of current group among all sentences in batch
+ batch_group_indices = []
- return topk_log_probs, topk_running_sequences, topk_running_beam_indices
+ for batch_idx in range(batch_size):
+ batch_group_indices.extend(
+ [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]
+ )
+ group_input_ids = input_ids[batch_group_indices]
- def _get_running_beams_for_next_iteration(
- self,
- topk_log_probs: ms.Tensor,
- topk_running_sequences: ms.Tensor,
- topk_running_beam_indices: ms.Tensor,
- next_token_hits_stopping_criteria: ms.Tensor,
- num_beams: int,
- ) -> Tuple[ms.Tensor, ms.Tensor, ms.Tensor]:
- """
- Given the top-K continuations, their scores, and whether they hit a stopping criteria, select the
- best non-finished beams to continue beam search in the next iteration.
- """
- # To prevent these just finished sequences from being used in subsequent iterations, set their log probs
- # to a very large negative value
- topk_running_log_probs = topk_log_probs + next_token_hits_stopping_criteria.to(ms.float32) * -1.0e9
+ # select outputs of beams of current group only
+ # No need to clone() the logits here as they will not retain outputs.logits at the end of the loop
+ # .float() is needed to retain precision for later logits manipulations
+ next_token_logits = outputs.logits[batch_group_indices, -1, :].to(dtype=ms.float32)
- next_topk_indices = mint.topk(topk_running_log_probs, k=num_beams)[1]
- running_sequences = self._gather_beams(topk_running_sequences, next_topk_indices)
- running_beam_scores = self._gather_beams(topk_running_log_probs, next_topk_indices)
- running_beam_indices = self._gather_beams(topk_running_beam_indices, next_topk_indices)
- return running_sequences, running_beam_scores, running_beam_indices
+ next_token_scores = mint.nn.functional.log_softmax(
+ next_token_logits, dim=-1
+ ) # (batch_size * group_size, vocab_size)
+ vocab_size = next_token_scores.shape[-1]
+
+ next_token_scores_processed = logits_processor(
+ group_input_ids, next_token_scores, current_tokens=current_tokens, beam_group_idx=beam_group_idx
+ )
+ next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1)
+ next_token_scores = next_token_scores.expand_as(next_token_scores_processed)
+
+ if output_scores:
+ processed_score[batch_group_indices] = next_token_scores_processed
+
+ # reshape for beam search
+ next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)
+
+ # Sample 1 + len(eos_token_id) next tokens for each beam so we have at least 1 non eos token per beam.
+ n_eos_tokens = eos_token_id.shape[0] if eos_token_id is not None else 0
+ next_token_scores, next_tokens = mint.topk(
+ next_token_scores, max(2, 1 + n_eos_tokens) * group_size, dim=1, largest=True, sorted=True
+ )
+
+ next_indices = mint.div(next_tokens, vocab_size, rounding_mode="floor")
+ next_tokens = next_tokens % vocab_size
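+ # e.g. (illustrative): with group_size=2 and vocab_size=32000, a flattened candidate id of 32005 decodes
+ # to beam 32005 // 32000 = 1 within the group and token id 32005 % 32000 = 5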
+
+ # stateless
+ process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
+ beam_outputs = beam_scorer.process(
+ group_input_ids,
+ next_token_scores,
+ next_tokens,
+ next_indices,
+ pad_token_id=pad_token_id,
+ eos_token_id=eos_token_id,
+ beam_indices=process_beam_indices,
+ group_index=beam_group_idx,
+ decoder_prompt_len=decoder_prompt_len,
+ )
+ beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
+ beam_next_tokens = beam_outputs["next_beam_tokens"]
+ beam_idx = beam_outputs["next_beam_indices"]
+
+ if return_dict_in_generate and output_scores:
+ beam_indices[beam_group_idx] = tuple(
+ beam_indices[beam_group_idx][beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices[0]))
+ )
+
+ input_ids[batch_group_indices] = group_input_ids[beam_idx]
+ group_input_ids = mint.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
+ current_tokens[batch_group_indices] = group_input_ids[:, -1]
+
+ # (beam_idx // group_size) -> batch_idx
+ # (beam_idx % group_size) -> offset of idx inside the group
+ reordering_indices[batch_group_indices] = (
+ num_beams * mint.div(beam_idx, group_size, rounding_mode="floor")
+ + group_start_idx
+ + (beam_idx % group_size)
+ )
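+ # e.g. (illustrative): with num_beams=6, group_size=2 and group_start_idx=2, beam_idx=3 belongs to batch
+ # item 3 // 2 = 1 with offset 3 % 2 = 1, i.e. global beam index 6 * 1 + 2 + 1 = 9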
+
+ # Store scores, attentions and hidden_states when required
+ if return_dict_in_generate:
+ if output_scores:
+ scores += (processed_score,)
+ if output_logits:
+ raw_logits += (raw_logit_score,)
+ if output_attentions:
+ decoder_attentions += (
+ (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
+ )
+ if self.config.is_encoder_decoder:
+ cross_attentions += (outputs.cross_attentions,)
+
+ if output_hidden_states:
+ decoder_hidden_states += (
+ (outputs.decoder_hidden_states,) if self.config.is_encoder_decoder else (outputs.hidden_states,)
+ )
+
+ input_ids = mint.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)
+
+ # This is needed to properly delete outputs.logits which may be very large for first iteration
+ # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
+ # IMPORTANT: Note that this should appear BEFORE the call to _reorder_cache() to save the maximum memory
+ # (that way the memory peak does not include outputs.logits)
+ del outputs
- def _update_finished_beams(
- self,
- sequences: ms.Tensor,
- topk_running_sequences: ms.Tensor,
- beam_scores: ms.Tensor,
- topk_log_probs: ms.Tensor,
- beam_indices: ms.Tensor,
- topk_running_beam_indices: ms.Tensor,
- is_sent_finished: ms.Tensor,
- next_token_hits_stopping_criteria: ms.Tensor,
- top_num_beam_mask: ms.Tensor,
- num_beams: int,
- cur_len: int,
- decoder_prompt_len: int,
- length_penalty: float,
- early_stopping: Union[bool, str],
- ) -> Tuple[ms.Tensor, ms.Tensor, ms.Tensor, ms.Tensor]:
- """
- Updates the finished beams if (and only if) there are new completed sequences that have a higher score than
- the current finished sequences.
- """
- # Only the top `num_beam` sequences can be considered for the final returned sequences. Remember: the
- # remaining sequences only exist as a backup to ensure that we have at least `num_beams` sequences to
- # continue.
- did_top_num_beams_just_finished = next_token_hits_stopping_criteria & top_num_beam_mask[None, :].to(ms.int32)
+ # NOTE: we need to check if `self._reorder_cache` exists for special models like RAG, RecurrentGemma etc.
+ if model_kwargs.get("past_key_values", None) is not None:
+ if hasattr(self, "_reorder_cache"):
+ model_kwargs["past_key_values"] = self._reorder_cache(
+ model_kwargs["past_key_values"], reordering_indices
+ )
+ else:
+ model_kwargs["past_key_values"].reorder_cache(reordering_indices)
- # Further process topk logits for the finished beams
- # - add length penalty
- topk_log_probs = topk_log_probs / ((cur_len + 1 - decoder_prompt_len) ** length_penalty)
- # - make sure no scores can be added anymore if beam is full and early stopping is on
- beams_in_batch_are_full = ops.all(is_sent_finished, axis=-1, keep_dims=True) & ms.Tensor(
- early_stopping is True, ms.int32
+ # increase cur_len
+ cur_len = cur_len + 1
+
+ if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
+ this_peer_finished = True
+
+ final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
+ sequence_outputs = beam_scorer.finalize(
+ input_ids,
+ beam_scores,
+ next_tokens,
+ next_indices,
+ pad_token_id=pad_token_id,
+ eos_token_id=eos_token_id,
+ max_length=stopping_criteria.max_length,
+ beam_indices=final_beam_indices,
+ decoder_prompt_len=decoder_prompt_len,
)
- topk_log_probs += beams_in_batch_are_full.to(ms.float32) * -1.0e9
- # - make sure still running sequences cannot be chosen as finalized beam
- topk_log_probs += (~did_top_num_beams_just_finished) * -1.0e9
- # Get finalized `num_beam` sequences for the next generation step -- combine the previous finalized
- # data with the new finalized sequences (if any, non-finalized sequences have a very large negative score
- # in this step), and keep the best `num_beams` sequences.
- merged_sequences = mint.cat((sequences, topk_running_sequences), dim=1)
- merged_scores = mint.cat((beam_scores, topk_log_probs), dim=1)
- merged_beam_indices = mint.cat((beam_indices, topk_running_beam_indices), dim=1)
- merged_is_sent_finished = mint.cat((is_sent_finished, did_top_num_beams_just_finished.to(ms.bool_)), dim=1)
- topk_merged_indices = mint.topk(merged_scores, k=num_beams)[1]
- sequences = self._gather_beams(merged_sequences, topk_merged_indices)
- beam_scores = self._gather_beams(merged_scores, topk_merged_indices)
- beam_indices = self._gather_beams(merged_beam_indices, topk_merged_indices)
- is_sent_finished = self._gather_beams(merged_is_sent_finished, topk_merged_indices)
- return sequences, beam_scores, beam_indices, is_sent_finished
+ if return_dict_in_generate:
+ if not output_scores:
+ sequence_outputs["sequence_scores"] = None
- # end of auxiliary functions for beam search
+ if self.config.is_encoder_decoder:
+ return GenerateBeamEncoderDecoderOutput(
+ sequences=sequence_outputs["sequences"],
+ sequences_scores=sequence_outputs["sequence_scores"],
+ scores=scores,
+ logits=raw_logits,
+ beam_indices=sequence_outputs["beam_indices"],
+ encoder_attentions=encoder_attentions,
+ encoder_hidden_states=encoder_hidden_states,
+ decoder_attentions=decoder_attentions,
+ cross_attentions=cross_attentions,
+ decoder_hidden_states=decoder_hidden_states,
+ past_key_values=model_kwargs.get("past_key_values"),
+ )
+ else:
+ return GenerateBeamDecoderOnlyOutput(
+ sequences=sequence_outputs["sequences"],
+ sequences_scores=sequence_outputs["sequence_scores"],
+ scores=scores,
+ logits=raw_logits,
+ beam_indices=sequence_outputs["beam_indices"],
+ attentions=decoder_attentions,
+ hidden_states=decoder_hidden_states,
+ past_key_values=model_kwargs.get("past_key_values"),
+ )
+ else:
+ return sequence_outputs["sequences"]
- def _beam_search(
+ def _constrained_beam_search(
self,
input_ids: ms.Tensor,
+ constrained_beam_scorer: ConstrainedBeamSearchScorer,
logits_processor: LogitsProcessorList,
stopping_criteria: StoppingCriteriaList,
generation_config: GenerationConfig,
@@ -3007,19 +4803,20 @@ def _beam_search(
**model_kwargs,
) -> Union[GenerateBeamOutput, ms.Tensor]:
r"""
- Generates sequences of token ids for models with a language modeling head using **beam search decoding** and
- can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
- If it's the first time you're diving into Beam Search, we recommend you read the following blog post:
- https://huggingface.co/blog/how-to-generate (especially the beam search section).
- You can recompute the sequence scores from the individual scores using the `compute_transition_scores` function
- (https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationMixin.compute_transition_scores)
+ Generates sequences of token ids for models with a language modeling head using **constrained beam search
+ decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
+
Parameters:
input_ids (`ms.Tensor` of shape `(batch_size*num_beams, sequence_length)`):
The sequence used as a prompt for the generation.
+ constrained_beam_scorer (`ConstrainedBeamSearchScorer`):
+ A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and
+ sorted during generation, while satisfying a list of positive constraints. For more information, the
+ documentation of [`ConstrainedBeamSearchScorer`] should be read.
logits_processor (`LogitsProcessorList`):
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
used to modify the prediction scores of the language modeling head applied at each generation step.
- stopping_criteria (`StoppingCriteriaList`:
+ stopping_criteria (`StoppingCriteriaList`):
An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
used to tell if the generation loop should stop.
generation_config ([`~generation.GenerationConfig`]):
@@ -3030,15 +4827,15 @@ def _beam_search(
model_kwargs:
Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
an encoder-decoder model the kwargs should include `encoder_outputs`.
+
Return:
- [`generation.GenerateBeamDecoderOnlyOutput`], [`~generation.GenerateBeamEncoderDecoderOutput`] or
+ [`~generation.GenerateBeamDecoderOnlyOutput`], [`~generation.GenerateBeamEncoderDecoderOutput`] or
`ms.Tensor`: A `ms.Tensor` containing the generated tokens (default behaviour) or a
[`~generation.GenerateBeamDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
`return_dict_in_generate=True` or a [`~generation.GenerateBeamEncoderDecoderOutput`] if
`model.config.is_encoder_decoder=True`.
"""
-
- # 1. init beam_search values
+ # init values
pad_token_id = generation_config._pad_token_tensor
eos_token_id = generation_config._eos_token_tensor
output_attentions = generation_config.output_attentions
@@ -3046,51 +4843,252 @@ def _beam_search(
output_scores = generation_config.output_scores
output_logits = generation_config.output_logits
return_dict_in_generate = generation_config.return_dict_in_generate
- do_sample = generation_config.do_sample
- early_stopping = generation_config.early_stopping
- length_penalty = generation_config.length_penalty
- max_length = generation_config.max_length
- num_beams = generation_config.num_beams
- num_return_sequences = generation_config.num_return_sequences
- batch_size_unflattened, cur_len = input_ids.shape
- batch_size = batch_size_unflattened // num_beams
- # TODO (joao): standardize special cases
- if self.__class__.__name__ == "MoshiDepthDecoder":
- vocab_size = self.config.audio_vocab_size
- elif self.__class__.__name__ == "ImageGPTForCausalImageModeling":
- vocab_size = self.get_output_embeddings().out_features
- else:
- vocab_size = self.config.get_text_config().vocab_size
- decoder_prompt_len = cur_len
- this_peer_finished = False
+ batch_size = len(constrained_beam_scorer._beam_hyps)
+ num_beams = constrained_beam_scorer.num_beams
- # At each beam search step, we want to keep top K [K = (number of EOS tokens + 1) * `num_beams`] candidates
- # with the highest log-probabilities, or sample K continuations without replacement. We gather the top K
- # (as opposed to `num_beams`, or any number lower than K) so that we have at least `num_beams` sequences
- # non-finished to continue the live beam search, in case the top `num_beams` all select an EOS token.
- n_eos_tokens = eos_token_id.shape[0] if eos_token_id is not None else 0
- beams_to_keep = max(2, 1 + n_eos_tokens) * num_beams
- top_num_beam_mask = mint.cat(
- (mint.ones((num_beams), dtype=ms.bool_), mint.zeros((beams_to_keep - num_beams), dtype=ms.bool_)),
- dim=0,
+ batch_beam_size, cur_len = input_ids.shape[:2]
+ model_kwargs = self._get_initial_cache_position(cur_len, model_kwargs)
+
+ if num_beams * batch_size != batch_beam_size:
+ raise ValueError(
+ f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
+ )
+
+ # init attention / hidden states / scores tuples
+ scores = () if (return_dict_in_generate and output_scores) else None
+ raw_logits = () if (return_dict_in_generate and output_logits) else None
+ beam_indices = (
+ tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None
)
+ decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
+ cross_attentions = () if (return_dict_in_generate and output_attentions) else None
+ decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
- model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
+ # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
+ if return_dict_in_generate and self.config.is_encoder_decoder:
+ encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
+ encoder_hidden_states = (
+ model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
+ )
- # (joao) feature lost in the refactor. Probably won't implement, hurts readbility with minimal gains (there
- # are newer low-memory alternatives like the offloaded cache)
- sequential = generation_config.low_memory
- if sequential:
- raise ValueError(
- "`low_memory=True` is not supported after the beam search refactor. Please check the discussion in "
- "#35802 *after the PR got merged*, and add a comment there if your questions are not yet answered."
+ # initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens
+ # of the first beam are considered to avoid sampling the exact same tokens across all beams.
+ beam_scores = mint.zeros((batch_size, num_beams), dtype=ms.float32)
+ beam_scores[:, 1:] = -1e9
+ beam_scores = beam_scores.view((batch_size * num_beams,))
+
+ this_peer_finished = False
+
+ decoder_prompt_len = input_ids.shape[1] # record the prompt length of decoder
+ while self._has_unfinished_sequences(this_peer_finished, synced_gpus):
+ model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+
+ # prepare variable output controls (note: some models won't accept all output controls)
+ model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
+ model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
+
+ outputs = self(**model_inputs, return_dict=True)
+
+ # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
+ model_kwargs = self._update_model_kwargs_for_generation(
+ outputs,
+ model_kwargs,
+ is_encoder_decoder=self.config.is_encoder_decoder,
+ )
+ if synced_gpus and this_peer_finished:
+ cur_len = cur_len + 1
+ continue
+
+ # Copy is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
+ # (the copy itself is always small)
+ # .float() is needed to retain precision for later logits manipulations
+ next_token_logits = outputs.logits[:, -1, :].to(dtype=ms.float32)
+ next_token_scores = mint.nn.functional.log_softmax(
+ next_token_logits, dim=-1
+ ) # (batch_size * num_beams, vocab_size)
+
+ next_token_scores_processed = logits_processor(input_ids, next_token_scores)
+
+ next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(
+ next_token_scores_processed
+ )
+
+ scores_for_all_vocab = next_token_scores.clone()
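+ # the full (batch_size * num_beams, vocab_size) scores are kept so the constrained scorer can also look up
+ # scores for constraint tokens that may fall outside the top-k selection below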
+
+ # Store scores, attentions and hidden_states when required
+ if return_dict_in_generate:
+ if output_scores:
+ scores += (next_token_scores,)
+ if output_logits:
+ raw_logits += (next_token_logits,)
+ if output_attentions:
+ decoder_attentions += (
+ (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
+ )
+ if self.config.is_encoder_decoder:
+ cross_attentions += (outputs.cross_attentions,)
+
+ if output_hidden_states:
+ decoder_hidden_states += (
+ (outputs.decoder_hidden_states,) if self.config.is_encoder_decoder else (outputs.hidden_states,)
+ )
+
+ # reshape for beam search
+ vocab_size = next_token_scores.shape[-1]
+ next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
+
+ # Sample 1 + len(eos_token_id) next tokens for each beam so we have at least 1 non eos token per beam.
+ n_eos_tokens = eos_token_id.shape[0] if eos_token_id is not None else 0
+ next_token_scores, next_tokens = mint.topk(
+ next_token_scores, max(2, 1 + n_eos_tokens) * num_beams, dim=1, largest=True, sorted=True
+ )
+
+ next_indices = (next_tokens / vocab_size).long()
+ next_tokens = next_tokens % vocab_size
+
+ # stateless
+ beam_outputs = constrained_beam_scorer.process(
+ input_ids,
+ next_token_scores,
+ next_tokens,
+ next_indices,
+ scores_for_all_vocab,
+ pad_token_id=pad_token_id,
+ eos_token_id=eos_token_id,
+ beam_indices=beam_indices,
+ decoder_prompt_len=decoder_prompt_len,
)
+ beam_scores = beam_outputs["next_beam_scores"]
+ beam_next_tokens = beam_outputs["next_beam_tokens"]
+ beam_idx = beam_outputs["next_beam_indices"]
+
+ input_ids = mint.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
+
+ # This is needed to properly delete outputs.logits which may be very large for first iteration
+ # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
+ # IMPORTANT: Note that this should appear BEFORE the call to _reorder_cache() to save the maximum memory
+ # (that way the memory peak does not include outputs.logits)
+ del outputs
+
+ # NOTE: we need to check if `self._reorder_cache` exists for special models like RAG, RecurrentGemma etc.
+ if model_kwargs.get("past_key_values", None) is not None:
+ if hasattr(self, "_reorder_cache"):
+ model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx)
+ else:
+ model_kwargs["past_key_values"].reorder_cache(beam_idx)
+
+ if return_dict_in_generate and output_scores:
+ beam_indices = tuple(beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices)))
+
+ # increase cur_len
+ cur_len = cur_len + 1
+
+ if constrained_beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
+ this_peer_finished = True
+
+ sequence_outputs = constrained_beam_scorer.finalize(
+ input_ids,
+ beam_scores,
+ next_tokens,
+ next_indices,
+ pad_token_id=pad_token_id,
+ eos_token_id=eos_token_id,
+ max_length=stopping_criteria.max_length,
+ beam_indices=beam_indices,
+ decoder_prompt_len=decoder_prompt_len,
+ )
+
+ if return_dict_in_generate:
+ if not output_scores:
+ sequence_outputs["sequence_scores"] = None
+ if self.config.is_encoder_decoder:
+ return GenerateBeamEncoderDecoderOutput(
+ sequences=sequence_outputs["sequences"],
+ sequences_scores=sequence_outputs["sequence_scores"],
+ scores=scores,
+ logits=raw_logits,
+ beam_indices=sequence_outputs["beam_indices"],
+ encoder_attentions=encoder_attentions,
+ encoder_hidden_states=encoder_hidden_states,
+ decoder_attentions=decoder_attentions,
+ cross_attentions=cross_attentions,
+ decoder_hidden_states=decoder_hidden_states,
+ past_key_values=model_kwargs.get("past_key_values"),
+ )
+ else:
+ return GenerateBeamDecoderOnlyOutput(
+ sequences=sequence_outputs["sequences"],
+ sequences_scores=sequence_outputs["sequence_scores"],
+ scores=scores,
+ logits=raw_logits,
+ beam_indices=sequence_outputs["beam_indices"],
+ attentions=decoder_attentions,
+ hidden_states=decoder_hidden_states,
+ past_key_values=model_kwargs.get("past_key_values"),
+ )
+ else:
+ return sequence_outputs["sequences"]
+
+ def _assisted_decoding(
+ self,
+ input_ids: ms.Tensor,
+ candidate_generator: CandidateGenerator,
+ logits_processor: LogitsProcessorList,
+ stopping_criteria: StoppingCriteriaList,
+ generation_config: GenerationConfig,
+ synced_gpus: bool,
+ streamer: Optional["BaseStreamer"],
+ **model_kwargs,
+ ) -> Union[GenerateNonBeamOutput, ms.Tensor]:
+ r"""
+ Generates sequences of token ids for models with a language modeling head using **greedy decoding** or
+ **sample** (depending on `do_sample`), assisted by candidate sequences. Assisted generation is an example of a
+ candidate decoding strategy. Can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text
+ models.
+
+ Parameters:
+ input_ids (`ms.Tensor` of shape `(batch_size, sequence_length)`):
+ The sequence used as a prompt for the generation.
+ candidate_generator (`CandidateGenerator`):
+ A derived instance of [`CandidateGenerator`] that defines how candidate sequences are generated. For
+ more information, the documentation of [`CandidateGenerator`] should be read.
+ logits_processor (`LogitsProcessorList`):
+ An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
+ used to modify the prediction scores of the language modeling head applied at each generation step.
+ stopping_criteria (`StoppingCriteriaList`):
+ An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
+ used to tell if the generation loop should stop.
+ generation_config ([`~generation.GenerationConfig`]):
+ The generation configuration to be used as parametrization of the decoding method.
+ synced_gpus (`bool`):
+ Whether to continue running the while loop until max_length (needed to avoid deadlocking with
+ `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
+ streamer (`BaseStreamer`, *optional*):
+ Streamer object that will be used to stream the generated sequences. Generated tokens are passed
+ through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
+ model_kwargs:
+ Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
+ If model is an encoder-decoder model the kwargs should include `encoder_outputs`.
+
+ Return:
+ [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or
+ `ms.Tensor`: A `ms.Tensor` containing the generated tokens (default behaviour) or a
+ [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
+ `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if
+ `model.config.is_encoder_decoder=True`.
+ """
+ # init values
+ do_sample = generation_config.do_sample
+ output_attentions = generation_config.output_attentions
+ output_hidden_states = generation_config.output_hidden_states
+ output_scores = generation_config.output_scores
+ output_logits = generation_config.output_logits
+ return_dict_in_generate = generation_config.return_dict_in_generate
- # 2. init output tuples
- all_scores = () if (return_dict_in_generate and output_scores) else None
+ # init attention / hidden states / scores tuples
+ scores = () if (return_dict_in_generate and output_scores) else None
raw_logits = () if (return_dict_in_generate and output_logits) else None
- beam_indices = () if (return_dict_in_generate and output_logits) else None
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
@@ -3102,205 +5100,180 @@ def _beam_search(
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
)
- # 3. init running tensors and static-shaped placeholders
-
- # per batch, beam-item holding current token in loop and completed sequences
- output_fill_value = pad_token_id or eos_token_id[0] if eos_token_id is not None else -1
- running_sequences = mint.full(
- (batch_size, num_beams, max_length),
- fill_value=output_fill_value,
- dtype=ms.int64,
- )
- running_sequences[:, :, :cur_len] = self._unflatten_beam_dim(input_ids, batch_size, num_beams)
- sequences = running_sequences.copy() # .detach()
+ # keep track of which sequences are already finished
+ batch_size, cur_len = input_ids.shape[:2]
+ unfinished_sequences = mint.ones(batch_size, dtype=ms.int64)
+ model_kwargs = self._get_initial_cache_position(cur_len, model_kwargs)
- # per batch, beam-item score, logprobs
- # initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens
- # of the first beam are considered to avoid sampling the exact same tokens across all beams.
- running_beam_scores = mint.zeros((batch_size, num_beams), dtype=ms.float32)
- running_beam_scores[:, 1:] = -1e9
- beam_scores = mint.full((batch_size, num_beams), fill_value=-1e9, dtype=ms.float32)
+ this_peer_finished = False
+ is_first_iteration = True # to preserve the same API in the output as other generation methods
+ while self._has_unfinished_sequences(this_peer_finished, synced_gpus):
+ cur_len = input_ids.shape[1]
- # per batch, beam-item state bit indicating if sentence has finished.
- is_sent_finished = mint.zeros((batch_size, num_beams), dtype=ms.bool_)
+ # 1. Fetch candidate sequences from a `CandidateGenerator` and move to the correct device
+ candidate_input_ids, candidate_logits = candidate_generator.get_candidates(input_ids)
- # per batch, beam-item state bit indicating if there are valid continuations.
- next_token_hits_stopping_criteria = mint.zeros((batch_size, num_beams), dtype=ms.bool_)
+ candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1]
+ is_done_candidate = stopping_criteria(candidate_input_ids, None)
- # per batch selected beam indices
- running_beam_indices = mint.full((batch_size, num_beams, max_length - cur_len), fill_value=-1, dtype=ms.int32)
- beam_indices = running_beam_indices.copy() # .detach()
+ # 2. Use the original model to obtain the next token logits given the candidate sequence. We obtain
+ # `candidate_length + 1` relevant logits from this process: in the event that all candidates are correct,
+ # we use this forward pass to also pick the subsequent logits in the original model.
- step = 0
- s_time = time.time()
- graph_compiled_time_buffer = []
+ # 2.1. Prepare the model inputs
+ candidate_kwargs = copy.copy(model_kwargs)
+ candidate_kwargs = _prepare_attention_mask(
+ candidate_kwargs, candidate_input_ids.shape[1], self.config.is_encoder_decoder
+ )
+ candidate_kwargs = _prepare_token_type_ids(candidate_kwargs, candidate_input_ids.shape[1])
+ if "cache_position" in candidate_kwargs:
+ candidate_kwargs["cache_position"] = mint.cat(
+ (
+ candidate_kwargs["cache_position"],
+ mint.arange(cur_len, cur_len + candidate_length, dtype=ms.int64),
+ ),
+ dim=0,
+ )
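+ # e.g. (illustrative): with cur_len=10 and candidate_length=3, cache_position is extended with [10, 11, 12]
+ # so a single forward pass scores every candidate token at once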
- # 4. run the generation loop
- while self._has_unfinished_sequences(this_peer_finished, synced_gpus):
- # a. Forward current tokens, obtain the logits
- flat_running_sequences = self._flatten_beam_dim(running_sequences[:, :, :cur_len])
- model_inputs = self.prepare_inputs_for_generation(flat_running_sequences, **model_kwargs)
+ model_inputs = self.prepare_inputs_for_generation(candidate_input_ids, **candidate_kwargs)
+ if "logits_to_keep" in model_inputs:
+ model_inputs["logits_to_keep"] = candidate_length + 1
+ # 2.2. Run a forward pass on the candidate sequence
# prepare variable output controls (note: some models won't accept all output controls)
model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
- model_outputs = self(**model_inputs, return_dict=True)
+ outputs = self(**model_inputs)
+
+ # 2.3. Process the new logits
+ # .float() is needed to retain precision for later logits manipulations
+ new_logits = outputs.logits[:, -candidate_length - 1 :].to(
+ dtype=ms.float32
+ ) # excludes the input prompt if present
+ next_token_logits = new_logits.clone()
+ if len(logits_processor) > 0:
+ for i in range(candidate_length + 1):
+ new_logits[:, i, :] = logits_processor(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :])
+
+ # 3. Select the accepted tokens. There are two possible cases:
+ # Case 1: `do_sample=True` and we have logits for the candidates (originally from speculative decoding)
+ # 👉 Apply algorithm 1 from the speculative decoding paper (https://huggingface.co/papers/2211.17192).
+ if do_sample and candidate_logits is not None:
+ valid_tokens, n_matches = _speculative_sampling(
+ candidate_input_ids,
+ candidate_logits,
+ candidate_length,
+ new_logits,
+ is_done_candidate,
+ )
+
+ # Case 2: all other cases (originally from assisted generation) 👉 Compare the tokens selected from the
+ # original model logits with the candidate tokens. We can keep the candidate tokens until the first
+ # mismatch, or until the max length is reached.
+ else:
+ if do_sample:
+ probs = new_logits.softmax(dim=-1)
+ selected_tokens = mint.nn.functional.multinomial(probs[0, :, :], num_samples=1).squeeze(1)[None, :]
+ else:
+ selected_tokens = new_logits.argmax(dim=-1)
+
+ candidate_new_tokens = candidate_input_ids[:, cur_len:]
+ n_matches = ((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum()
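+            # e.g. if the candidate tokens are [5, 9, 2] and the model would have selected [5, 9, 7, 1], the
+            # cumulative mismatch count stays at 0 only for the two leading matches, so n_matches == 2 here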
+
+ # Ensure we don't generate beyond max_len or an EOS token
+ if is_done_candidate and n_matches == candidate_length:
+ n_matches -= 1
+ valid_tokens = selected_tokens[:, : n_matches + 1]
+
+ # 4. Update variables according to the number of matching assistant tokens. Remember: the token generated
+ # by the model after the last candidate match is also valid, as it is generated from a correct sequence.
+ # Because of this last token, assisted generation search reduces to a normal greedy search/sample if there
+ # is no match.
+
+ # 4.1. Get the valid continuation, after the matching tokens
+ input_ids = mint.cat((input_ids, valid_tokens), dim=-1)
+ if streamer is not None:
+ streamer.put(valid_tokens.cpu())
+ new_cur_len = input_ids.shape[1]
+
+ # 4.2. Discard past key values relative to unused assistant tokens
+ outputs.past_key_values.crop(new_cur_len - 1)
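+            # the cache now holds the first `new_cur_len - 1` positions; the newest accepted token is re-processed
+            # together with the next batch of candidate tokens on the following iteration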
+
+ # 5. Update the candidate generation strategy if needed
+ candidate_generator.update_candidate_strategy(input_ids, new_logits, n_matches)
# synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
model_kwargs = self._update_model_kwargs_for_generation(
- model_outputs,
+ outputs,
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
+ num_new_tokens=n_matches + 1,
)
if synced_gpus and this_peer_finished:
continue
- step_time = time.time() - s_time
- if step < 2:
- print(f"==> sampling, step: {step}, time cost: {step_time:.5f}s")
- else:
- graph_compiled_time_buffer.append(step_time)
- token_speed = len(graph_compiled_time_buffer) / sum(graph_compiled_time_buffer)
- print(
- f"==> sampling, step: {step}, time cost: {step_time:.5f}s, running avg speed: {token_speed:.5f}token/s"
- )
- s_time = time.time()
- step += 1
-
- logits = model_outputs.logits[:, -1, :].copy().float() # copy is needed to avoid keeping a hanging ref
-
- # b. Compute log probs -- get log probabilities from logits, process logits with processors (*e.g.*
- # `temperature`, ...), and add new logprobs to existing running logprobs scores.
- log_probs = mint.nn.functional.log_softmax(logits, dim=-1)
- log_probs = logits_processor(flat_running_sequences, log_probs)
-
- # Store logits, attentions and hidden_states when required
+ # Store scores, attentions and hidden_states when required
+ # Assistant: modified to append one tuple element per token, as in the other generation methods.
if return_dict_in_generate:
+ newly_added_length = n_matches + 1
+ if output_scores:
+ scores += tuple(new_logits[:, i, :] for i in range(newly_added_length))
if output_logits:
- raw_logits += (logits.copy(),)
- if return_dict_in_generate and output_scores:
- all_scores += (log_probs.copy(),)
+ raw_logits += tuple(next_token_logits[:, i, :] for i in range(newly_added_length))
+ newly_added_length = new_cur_len if is_first_iteration else newly_added_length
if output_attentions:
- decoder_attentions += (
- (model_outputs.decoder_attentions,)
- if self.config.is_encoder_decoder
- else (model_outputs.attentions,)
- )
if self.config.is_encoder_decoder:
- cross_attentions += (model_outputs.cross_attentions,)
-
+ cross_attentions = _split_model_outputs(
+ cross_attentions, outputs.cross_attentions, cur_len, newly_added_length
+ )
+ decoder_attentions = _split_model_outputs(
+ decoder_attentions,
+ outputs.decoder_attentions,
+ cur_len,
+ newly_added_length,
+ is_decoder_attention=True,
+ )
+                    # some (V)LLMs have a hard requirement on SDPA and thus never return attn
+ elif outputs.attentions[0] is not None:
+ decoder_attentions = _split_model_outputs(
+ decoder_attentions,
+ outputs.attentions,
+ cur_len,
+ newly_added_length,
+ is_decoder_attention=True,
+ )
if output_hidden_states:
- decoder_hidden_states += (
- (model_outputs.decoder_hidden_states,)
- if self.config.is_encoder_decoder
- else (model_outputs.hidden_states,)
- )
-
- # This is needed to properly delete logits which may be very large for first iteration
- # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
- del model_outputs
-
- log_probs = self._unflatten_beam_dim(log_probs, batch_size, num_beams)
- log_probs = log_probs + running_beam_scores[:, :, None]
- log_probs = mint.reshape(log_probs, (batch_size, num_beams * vocab_size))
-
- # c. Retrieve top-K continuations, i.e. select the next token (greedy or sampling) and then keep the best
- # continuations among all beams based on the accumulated scores.
- topk_log_probs, topk_running_sequences, topk_running_beam_indices = self._get_top_k_continuations(
- accumulated_log_probs=log_probs,
- running_sequences=running_sequences,
- running_beam_indices=running_beam_indices,
- cur_len=cur_len,
- decoder_prompt_len=decoder_prompt_len,
- do_sample=do_sample,
- beams_to_keep=beams_to_keep,
- num_beams=num_beams,
- vocab_size=vocab_size,
- batch_size=batch_size,
- )
-
- # d. Check which running sequences have finished
- next_token_hits_stopping_criteria = stopping_criteria(
- self._flatten_beam_dim(topk_running_sequences[:, :, : cur_len + 1]), # remove unfilled token indexes
- all_scores,
- )
- next_token_hits_stopping_criteria = self._unflatten_beam_dim(
- next_token_hits_stopping_criteria, batch_size, beams_to_keep
- )
-
- # e. Get the non-finished running `num_beams` sequences for the next generation step
- running_sequences, running_beam_scores, running_beam_indices = self._get_running_beams_for_next_iteration(
- topk_log_probs=topk_log_probs,
- topk_running_sequences=topk_running_sequences,
- topk_running_beam_indices=topk_running_beam_indices,
- next_token_hits_stopping_criteria=next_token_hits_stopping_criteria,
- num_beams=num_beams,
- )
-
- # f. Update the completed beams if a new high score in a finished sequence is found
- sequences, beam_scores, beam_indices, is_sent_finished = self._update_finished_beams(
- sequences=sequences,
- topk_running_sequences=topk_running_sequences,
- beam_scores=beam_scores,
- topk_log_probs=topk_log_probs,
- beam_indices=beam_indices,
- topk_running_beam_indices=topk_running_beam_indices,
- is_sent_finished=is_sent_finished,
- next_token_hits_stopping_criteria=next_token_hits_stopping_criteria,
- top_num_beam_mask=top_num_beam_mask,
- num_beams=num_beams,
- cur_len=cur_len,
- decoder_prompt_len=decoder_prompt_len,
- length_penalty=length_penalty,
- early_stopping=early_stopping,
- )
+ if self.config.is_encoder_decoder:
+ decoder_hidden_states = _split_model_outputs(
+ decoder_hidden_states, outputs.decoder_hidden_states, cur_len, newly_added_length
+ )
+ else:
+ decoder_hidden_states = _split_model_outputs(
+ decoder_hidden_states, outputs.hidden_states, cur_len, newly_added_length
+ )
- # g. Prepare remaining data for the next iteration, including computing the stopping condition for
- # beam search as a whole (as opposed to individual beams, i.e. `stopping_criteria`)
+ unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
+ this_peer_finished = unfinished_sequences.max() == 0
+ is_first_iteration = False
- # pluck the cache from the beam indices that will be used in the next iteration
- if model_kwargs.get("past_key_values", None) is not None:
- model_kwargs["past_key_values"] = self._temporary_reorder_cache(
- past_key_values=model_kwargs["past_key_values"],
- beam_idx=self._flatten_beam_dim(running_beam_indices[..., cur_len - decoder_prompt_len]),
- )
+ if streamer is not None:
+ streamer.end()
- cur_len = cur_len + 1
- this_peer_finished = not self._beam_search_has_unfinished_sequences(
- running_beam_scores,
- beam_scores,
- is_sent_finished,
- next_token_hits_stopping_criteria,
- cur_len,
- max_length,
- decoder_prompt_len,
- early_stopping,
- length_penalty,
+ if (
+ hasattr(candidate_generator, "assistant_model")
+ and candidate_generator.assistant_model.generation_config.num_assistant_tokens_schedule == "heuristic"
+ ):
+ candidate_generator.assistant_model.generation_config.num_assistant_tokens = (
+ candidate_generator.num_assistant_tokens
)
-
- # 5. prepare outputs
- # Take best beams for each batch (the score is sorted in descending order)
- sequences = self._flatten_beam_dim(sequences[:, :num_return_sequences, :])
- beam_scores = self._flatten_beam_dim(beam_scores[:, :num_return_sequences])
- beam_indices = self._flatten_beam_dim(beam_indices[:, :num_return_sequences, :])
-
- # Crop the static-shaped tensors to the actual size
- sequences = sequences[:, :cur_len]
- beam_indices = beam_indices[:, : cur_len - decoder_prompt_len]
-
if return_dict_in_generate:
- if not output_scores:
- beam_scores = None
-
if self.config.is_encoder_decoder:
- return GenerateBeamEncoderDecoderOutput(
- sequences=sequences,
- sequences_scores=beam_scores,
- scores=all_scores,
+ return GenerateEncoderDecoderOutput(
+ sequences=input_ids,
+ scores=scores,
logits=raw_logits,
- beam_indices=beam_indices,
encoder_attentions=encoder_attentions,
encoder_hidden_states=encoder_hidden_states,
decoder_attentions=decoder_attentions,
@@ -3309,15 +5282,285 @@ def _beam_search(
past_key_values=model_kwargs.get("past_key_values"),
)
else:
- return GenerateBeamDecoderOnlyOutput(
- sequences=sequences,
- sequences_scores=beam_scores,
- scores=all_scores,
+ return GenerateDecoderOnlyOutput(
+ sequences=input_ids,
+ scores=scores,
logits=raw_logits,
- beam_indices=beam_indices,
attentions=decoder_attentions,
hidden_states=decoder_hidden_states,
past_key_values=model_kwargs.get("past_key_values"),
)
else:
- return sequences
+ return input_ids
+
+ def _prefill_chunking(self, input_ids: ms.Tensor, generation_config: GenerationConfig, **model_kwargs):
+ chunk_size = generation_config.prefill_chunk_size
+        # Only chunk up to the token just before the last one, so that decoding is performed entirely outside this
+        # function (here we simply prefill the cache)
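+        # e.g. with a 10-token prompt and prefill_chunk_size=4, the first 9 tokens are prefilled in chunks of
+        # sizes [4, 4, 1]; the final prompt token is then consumed by the regular decoding loop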
+ input_chunks = mint.split(input_ids[:, :-1], chunk_size, dim=-1)
+
+ if "past_key_values" not in model_kwargs:
+ raise ValueError("Cannot use prefill chunking without a cache")
+
+ model_forward = self.construct
+
+ compile_forward = self._valid_auto_compile_criteria(model_kwargs, generation_config)
+ if compile_forward:
+ model_forward = self.get_compiled_call(generation_config.compile_config)
+
+ attention_mask = model_kwargs.pop("attention_mask", None)
+
+ past_length = 0
+ for input_chunk in input_chunks:
+ current_length = past_length + input_chunk.shape[-1]
+ # Prepare inputs
+ if attention_mask is not None:
+ model_kwargs["attention_mask"] = attention_mask[:, :current_length]
+ model_kwargs["cache_position"] = mint.arange(past_length, current_length, dtype=ms.int64)
+ model_kwargs["position_ids"] = model_kwargs["cache_position"].unsqueeze(0)
+ model_inputs = self.prepare_inputs_for_generation(input_chunk, **model_kwargs)
+
+ outputs = model_forward(**model_inputs, return_dict=True)
+
+ model_kwargs["past_key_values"] = outputs.past_key_values
+ past_length = current_length
+
+ model_kwargs["attention_mask"] = attention_mask
+ model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1
+ _ = model_kwargs.pop("position_ids", None)
+
+ return model_kwargs
+
+
+def _speculative_sampling(
+ candidate_input_ids,
+ candidate_logits,
+ candidate_length,
+ new_logits,
+ is_done_candidate,
+):
+ """
+ Applies sampling as in the speculative decoding paper (https://huggingface.co/papers/2211.17192, algorithm 1). Returns
+ the selected tokens, as well as the number of candidate matches.
+
+ NOTE: Unless otherwise stated, the variable names match those in the paper.
+ """
+ new_candidate_input_ids = candidate_input_ids[:, -candidate_length:]
+ # Gets the probabilities from the logits. q_i and p_i denote the assistant and model probabilities of the tokens
+ # selected by the assistant, respectively.
+ q = candidate_logits.softmax(dim=-1)
+ q_i = q[:, mint.arange(candidate_length), new_candidate_input_ids].squeeze(0, 1)
+ p = new_logits.softmax(dim=-1)
+ p_i = p[:, mint.arange(candidate_length), new_candidate_input_ids].squeeze(0, 1)
+ probability_ratio = p_i / q_i
+
+ # When probability_ratio > 1 (i.e. q_i(x) < p_i(x), or "assistant probability of the candidate token is smaller
+ # than the model probability for the same token"), keep the token. Otherwise reject with p = 1 - probability_ratio
+ # (= keep with p = probability_ratio). Keep all the tokens until the first rejection
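+    # e.g. if p_i = 0.3 and q_i = 0.6, the ratio is 0.5 and the candidate token is kept with probability 0.5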
+ r_i = mint.rand_like(probability_ratio)
+ is_accepted = r_i <= probability_ratio
+ n_matches = ((~is_accepted).cumsum(dim=-1) < 1).sum() # this is `n` in algorithm 1
+
+ # Ensure we don't generate beyond max_len or an EOS token (not in algorithm 1, but needed for correct behavior)
+ if is_done_candidate and n_matches == candidate_length:
+ # Output length is assumed to be `n_matches + 1`. Since we won't generate another token with the target model
+ # due to acceptance on EOS we fix `n_matches`
+ n_matches -= 1
+ valid_tokens = new_candidate_input_ids[:, : n_matches + 1]
+ else:
+ # Next token selection: if there is a rejection, adjust the distribution from the main model before sampling.
+ gamma = candidate_logits.shape[1]
+ p_n_plus_1 = p[:, n_matches, :]
+ if n_matches < gamma:
+ q_n_plus_1 = q[:, n_matches, :]
+ p_prime = mint.clamp((p_n_plus_1 - q_n_plus_1), min=0)
+ p_prime.div_(p_prime.sum())
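+            # p_prime is the adjusted (residual) distribution from algorithm 1: max(0, p - q), renormalized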
+ else:
+ p_prime = p_n_plus_1
+ t = mint.nn.functional.multinomial(p_prime, num_samples=1).squeeze(1)[None, :]
+
+ # The selected tokens include the matches (if any) plus the next sampled tokens
+ if n_matches > 0:
+ valid_tokens = mint.cat((new_candidate_input_ids[:, :n_matches], t), dim=-1)
+ else:
+ valid_tokens = t
+
+ return valid_tokens, n_matches
+
+
+def _split_model_outputs(outputs, new_outputs, cur_len, added_len, is_decoder_attention=False):
+ """
+ Given the (decoder/cross attentions)/(decoder hidden states) for multiple generated tokens, splits it into a tuple
+ where each member corresponds to a single generated token.
+ """
+ # Retrocompatibility: in our generation functions, the first iteration includes the attention/hidden states for the
+ # prompt.
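+    # e.g. for decoder self-attention of shape [batch, heads, query_len, key_len], the first stored entry keeps the
+    # prompt block [..., :cur_len, :cur_len]; every later entry is a single-token slice along the query dimension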
+ if len(outputs) == 0:
+ new_tuple = ()
+ for layer in new_outputs:
+ last_dim_size = cur_len if is_decoder_attention else layer.shape[-1]
+ new_tuple += (layer[..., :cur_len, :last_dim_size],)
+ outputs += (new_tuple,)
+ # The first iteration contains the prompt + 1 generated token, let's update the length variables accordingly
+ cur_len += 1
+ added_len -= cur_len
+
+ for i in range(added_len):
+ new_tuple = ()
+ for layer in new_outputs:
+ last_dim_size = cur_len + i if is_decoder_attention else layer.shape[-1]
+ new_tuple += (layer[..., i : i + 1, :last_dim_size],)
+ outputs += (new_tuple,)
+ return outputs
+
+
+def _ranking_fast(
+ context_hidden: ms.Tensor,
+ next_hidden: ms.Tensor,
+ next_top_k_probs: ms.Tensor,
+ cosine_matrix_mask: ms.Tensor,
+ alpha: float,
+ beam_width: int,
+) -> ms.Tensor:
+ """
+ Reranks the top_k candidates based on a degeneration penalty (cosine similarity with previous tokens), as described
+ in the paper "A Contrastive Framework for Neural Text Generation". Returns the index of the best candidate for each
+ row in the batch.
+ """
+ norm_context_hidden = context_hidden / context_hidden.norm(dim=2, keepdim=True)
+ norm_next_hidden = next_hidden / next_hidden.norm(dim=2, keepdim=True)
+ cosine_matrix = mint.matmul(norm_context_hidden, norm_next_hidden.swapaxes(1, 2)).squeeze(-1) # [B*K, S]
+
+ # Penalize cosine_matrix based on the cosine_matrix_mask (ignore padding positions)
+ # Using a large negative value for masked positions
+ cosine_matrix_mask = cosine_matrix_mask.to(dtype=cosine_matrix.dtype)
+ cosine_matrix_mask = (1 - cosine_matrix_mask) * dtype_to_min(cosine_matrix.dtype)
+ cosine_matrix = cosine_matrix + cosine_matrix_mask
+
+ degeneration_penalty, _ = mint.max(cosine_matrix, dim=-1) # [B*K]
+ next_top_k_probs = next_top_k_probs.view(-1) # [B*K]
+ contrastive_score = (1.0 - alpha) * next_top_k_probs - alpha * degeneration_penalty
+ contrastive_score = mint.stack(mint.split(contrastive_score, beam_width)) # [B, K]
+ _, selected_idx = contrastive_score.max(dim=-1) # [B]
+ return selected_idx
+
+
+def stack_model_outputs(model_outputs: list[ModelOutput], config: PretrainedConfig) -> ModelOutput:
+ """
+ Stack a list of ModelOutput objects (or its subclasses) along the batch_size dimension. The function infers the
+ specific ModelOutput subclass from the list provided.
+ """
+ if not model_outputs:
+ raise ValueError("Input list is empty.")
+
+ # Infer the class from the first object in the list
+ model_output_cls = type(model_outputs[0])
+
+ # Ensure all objects are of the same type
+ if not all(isinstance(obj, model_output_cls) for obj in model_outputs):
+ raise ValueError("All elements in the list should be of the same type.")
+
+ # Helper function to concat tensors or tuples of tensors
+ def _concat(data):
+ """
+ Reverse of `_split` function above.
+ """
+ if any(data is None for data in data):
+ return None
+ if isinstance(data[0], ms.Tensor):
+ return mint.cat(data, dim=0)
+ elif isinstance(data[0], tuple):
+ # If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example)
+ if isinstance(data[0][0], tuple):
+ return tuple(
+ tuple(mint.cat([attr[i][j] for attr in data], dim=0) for j in range(len(data[0][0])))
+ for i in range(len(data[0]))
+ )
+ else:
+ return tuple(mint.cat([attr[i] for attr in data], dim=0) for i in range(len(data[0])))
+ elif isinstance(data[0], (int, float)):
+ # If the elements are integers or floats, return a tensor
+ return ms.Tensor(data)
+ else:
+ raise TypeError(f"Unexpected attribute type: {type(data[0])}")
+
+ # Use a dictionary comprehension to gather attributes from all objects and concatenate them
+ concatenated_data = {
+ k: _concat([getattr(model_output, k) for model_output in model_outputs])
+ for k in model_output_cls.__dataclass_fields__.keys()
+ }
+
+ # Return a new object of the inferred class with the concatenated attributes
+ return model_output_cls(**concatenated_data)
+
+
+def _relative_top_filter(
+ scores: ms.Tensor,
+ baseline_scores: ms.Tensor,
+ relative_top: float = 0.1,
+ filter_value: float = -float("Inf"),
+ base_filter_value=-1e-3,
+ min_tokens_to_keep: int = 1,
+) -> tuple[ms.Tensor, ms.Tensor]:
+ """
+ Apply filtering to only keep tokens with a probability above a certain threshold.
+ The threshold is defined as `relative_top` * max probability in the distribution.
+ """
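+    # e.g. with relative_top=0.1 and a maximum token probability of 0.4, tokens with probability below 0.04 are
+    # filtered out (while always keeping at least `min_tokens_to_keep` tokens)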
+ scores_normalized = scores.log_softmax(dim=-1)
+ baseline_scores_normalized = baseline_scores.log_softmax(dim=-1)
+ sorted_logits, sorted_indices = mint.sort(scores_normalized, descending=True)
+ min_thresh = sorted_logits[..., min_tokens_to_keep - 1]
+ probs_max = mint.max(scores_normalized, dim=-1).values
+ probs_thresh = probs_max + np.log(relative_top)
+ probs_thresh = mint.min(min_thresh, probs_thresh)
+ probs_thresh = probs_thresh.unsqueeze(-1)
+ baseline_scores_normalized[scores_normalized < probs_thresh] = base_filter_value
+ scores_normalized[scores_normalized < probs_thresh] = filter_value
+ return scores_normalized, baseline_scores_normalized
+
+
+def _dola_select_contrast(
+ candidate_premature_layers: list[int],
+ candidate_premature_logits: dict[int, ms.Tensor],
+ final_logits: ms.Tensor,
+) -> ms.Tensor:
+ if len(candidate_premature_layers) == 1:
+ base_logits = candidate_premature_logits[candidate_premature_layers[0]]
+ final_logits, base_logits = _relative_top_filter(final_logits, base_logits)
+ logits = final_logits - base_logits
+ return logits
+
+ # 1. Stacking all premature_layers into a new dimension
+ stacked_premature_layers = mint.stack([candidate_premature_logits[i] for i in candidate_premature_layers], dim=0)
+
+ # 2. Calculate the softmax values for mature_layer and all premature_layers
+ # shape: (batch_size, vocab_size)
+ softmax_mature_layer = F.softmax(final_logits, dim=-1)
+ # shape: (num_premature_layers, batch_size, vocab_size)
+ softmax_premature_layers = F.softmax(stacked_premature_layers, dim=-1)
+
+ # 3. Calculate the average distribution
+ # shape: (num_premature_layers, batch_size, vocab_size)
+ avg_dist = 0.5 * (softmax_mature_layer[None, :, :] + softmax_premature_layers)
+
+ # 4. Calculate log-softmax for the KL divergence
+ # shape: (batch_size, vocab_size)
+ log_softmax_mature_layer = F.log_softmax(final_logits, dim=-1)
+ # shape: (num_premature_layers, batch_size, vocab_size)
+ log_softmax_premature_layers = F.log_softmax(stacked_premature_layers, dim=-1)
+
+ # 5. Calculate the KL divergences and then the JS divergences
+ # shape: (num_premature_layers, batch_size)
+ kl1 = F.kl_div(log_softmax_mature_layer[None, :, :], avg_dist, reduction="none").mean(-1)
+ # shape: (num_premature_layers, batch_size)
+ kl2 = F.kl_div(log_softmax_premature_layers, avg_dist, reduction="none").mean(-1)
+ js_divs = 0.5 * (kl1 + kl2) # shape: (num_premature_layers, batch_size)
+
+ # 6. Reduce the batchmean
+ js_divs = js_divs.mean(-1) # shape: (num_premature_layers,)
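+    # the premature layer whose distribution diverges most from the final layer (largest JS divergence) is selected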
+ premature_layer = candidate_premature_layers[int(js_divs.argmax().item())]
+
+ base_logits = candidate_premature_logits[premature_layer]
+ final_logits, base_logits = _relative_top_filter(final_logits, base_logits)
+ logits = final_logits - base_logits
+ return logits
diff --git a/mindone/transformers/image_processing_base.py b/mindone/transformers/image_processing_base.py
index 18638c8a77..d461c66fd2 100644
--- a/mindone/transformers/image_processing_base.py
+++ b/mindone/transformers/image_processing_base.py
@@ -40,7 +40,8 @@
)
from .feature_extraction_utils import BatchFeature as BaseBatchFeature
-from .utils import add_model_info_to_auto_map, add_model_info_to_custom_pipelines, is_vision_available
+from .image_utils import is_valid_image
+from .utils import is_vision_available
if is_vision_available():
from PIL import Image
@@ -136,7 +137,7 @@ def from_pretrained(
'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
token (`str` or `bool`, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
- the token generated when running `huggingface-cli login` (stored in `~/.huggingface`).
+ the token generated when running `hf auth login` (stored in `~/.huggingface`).
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
@@ -352,13 +353,13 @@ def get_image_processor_dict(
revision=revision,
subfolder=subfolder,
)
- except EnvironmentError:
+ except OSError:
# Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to
# the original exception.
raise
except Exception:
# For any other exception, we throw a generic error.
- raise EnvironmentError(
+ raise OSError(
f"Can't load image processor for '{pretrained_model_name_or_path}'. If you were trying to load"
" it from 'https://huggingface.co/models', make sure you don't have a local directory with the"
f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
@@ -367,12 +368,12 @@ def get_image_processor_dict(
try:
# Load image_processor dict
- with open(resolved_image_processor_file, "r", encoding="utf-8") as reader:
+ with open(resolved_image_processor_file, encoding="utf-8") as reader:
text = reader.read()
image_processor_dict = json.loads(text)
except json.JSONDecodeError:
- raise EnvironmentError(
+ raise OSError(
f"It looks like the config file at '{resolved_image_processor_file}' is not a valid JSON file."
)
@@ -382,14 +383,6 @@ def get_image_processor_dict(
logger.info(
f"loading configuration file {image_processor_file} from cache at {resolved_image_processor_file}"
)
- if "auto_map" in image_processor_dict:
- image_processor_dict["auto_map"] = add_model_info_to_auto_map(
- image_processor_dict["auto_map"], pretrained_model_name_or_path
- )
- if "custom_pipelines" in image_processor_dict:
- image_processor_dict["custom_pipelines"] = add_model_info_to_custom_pipelines(
- image_processor_dict["custom_pipelines"], pretrained_model_name_or_path
- )
return image_processor_dict, kwargs
@@ -464,7 +457,7 @@ def from_json_file(cls, json_file: Union[str, os.PathLike]):
A image processor of type [`~image_processing_utils.ImageProcessingMixin`]: The image_processor object
instantiated from that JSON file.
"""
- with open(json_file, "r", encoding="utf-8") as reader:
+ with open(json_file, encoding="utf-8") as reader:
text = reader.read()
image_processor_dict = json.loads(text)
return cls(**image_processor_dict)
@@ -510,11 +503,6 @@ def register_for_auto_class(cls, auto_class="AutoImageProcessor"):
Register this class with a given auto class. This should only be used for custom image processors as the ones
in the library are already mapped with `AutoImageProcessor `.
-
-
- This API is experimental and may have some slight breaking changes in the next releases.
-
-
Args:
auto_class (`str` or `type`, *optional*, defaults to `"AutoImageProcessor "`):
@@ -549,6 +537,8 @@ def fetch_images(self, image_url_or_urls: Union[str, List[str]]):
response = requests.get(image_url_or_urls, stream=True, headers=headers)
response.raise_for_status()
return Image.open(BytesIO(response.content))
+ elif is_valid_image(image_url_or_urls):
+ return image_url_or_urls
else:
raise TypeError(f"only a single or a list of entries is supported but got type={type(image_url_or_urls)}")
diff --git a/mindone/transformers/image_processing_utils_fast.py b/mindone/transformers/image_processing_utils_fast.py
index abcb01e6ea..117d675e20 100644
--- a/mindone/transformers/image_processing_utils_fast.py
+++ b/mindone/transformers/image_processing_utils_fast.py
@@ -16,14 +16,13 @@
# limitations under the License.
from collections.abc import Iterable
+from copy import deepcopy
from functools import lru_cache, partial
from typing import Any, Optional, TypedDict, Union
import numpy as np
from PIL import Image
-from transformers.utils import add_start_docstrings, logging
-
-from mindspore import mint
+from transformers.utils import auto_docstring, logging
from .image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from .image_transforms import (
@@ -55,6 +54,8 @@
if is_mindspore_available():
import mindspore as ms
+ import mindspore.mint.nn.functional as F
+ from mindspore import mint
from mindspore.dataset import vision
from mindspore.dataset.vision import Inter as InterpolationMode
@@ -174,103 +175,7 @@ class DefaultFastImageProcessorKwargs(TypedDict, total=False):
input_data_format: Optional[Union[str, ChannelDimension]]
-BASE_IMAGE_PROCESSOR_FAST_DOCSTRING = r"""
-
- Args:
- do_resize (`bool`, *optional*, defaults to `self.do_resize`):
- Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
- `do_resize` parameter in the `preprocess` method.
- size (`dict`, *optional*, defaults to `self.size`):
- Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
- method.
- default_to_square (`bool`, *optional*, defaults to `self.default_to_square`):
- Whether to default to a square image when resizing, if size is an int.
- resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
- Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`. Can be
- overridden by the `resample` parameter in the `preprocess` method.
- do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
- Whether to center crop the image to the specified `crop_size`. Can be overridden by `do_center_crop` in the
- `preprocess` method.
- crop_size (`Dict[str, int]` *optional*, defaults to `self.crop_size`):
- Size of the output image after applying `center_crop`. Can be overridden by `crop_size` in the `preprocess`
- method.
- do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
- Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the
- `do_rescale` parameter in the `preprocess` method.
- rescale_factor (`int` or `float`, *optional*, defaults to `self.rescale_factor`):
- Scale factor to use if rescaling the image. Only has an effect if `do_rescale` is set to `True`. Can be
- overridden by the `rescale_factor` parameter in the `preprocess` method.
- do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
- Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
- method. Can be overridden by the `do_normalize` parameter in the `preprocess` method.
- image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
- Mean to use if normalizing the image. This is a float or list of floats the length of the number of
- channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. Can be
- overridden by the `image_mean` parameter in the `preprocess` method.
- image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
- Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
- number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
- Can be overridden by the `image_std` parameter in the `preprocess` method.
- do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
- Whether to convert the image to RGB.
- return_tensors (`str` or `TensorType`, *optional*, defaults to `self.return_tensors`):
- Returns stacked tensors if set to `pt, otherwise returns a list of tensors.
- data_format (`ChannelDimension` or `str`, *optional*, defaults to `self.data_format`):
- Only `ChannelDimension.FIRST` is supported. Added for compatibility with slow processors.
- input_data_format (`ChannelDimension` or `str`, *optional*, defaults to `self.input_data_format`):
- The channel dimension format for the input image. If unset, the channel dimension format is inferred
- from the input image. Can be one of:
- - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- - `"none"` or `ChannelDimension.NONE`: image in (height, width) format."""
-
-BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS = r"""
- Preprocess an image or batch of images.
-
- Args:
- images (`ImageInput`):
- Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
- passing in images with pixel values between 0 and 1, set `do_rescale=False`.
- do_resize (`bool`, *optional*, defaults to `self.do_resize`):
- Whether to resize the image.
- size (`Dict[str, int]`, *optional*, defaults to `self.size`):
- Describes the maximum input dimensions to the model.
- resample (`PILImageResampling` or `InterpolationMode`, *optional*, defaults to `self.resample`):
- Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
- has an effect if `do_resize` is set to `True`.
- do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
- Whether to center crop the image.
- crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
- Size of the output image after applying `center_crop`.
- do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
- Whether to rescale the image.
- rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
- Rescale factor to rescale the image by if `do_rescale` is set to `True`.
- do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
- Whether to normalize the image.
- image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
- Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
- image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
- Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
- `True`.
- do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
- Whether to convert the image to RGB.
- return_tensors (`str` or `TensorType`, *optional*, defaults to `self.return_tensors`):
- Returns stacked tensors if set to `pt, otherwise returns a list of tensors.
- data_format (`ChannelDimension` or `str`, *optional*, defaults to `self.data_format`):
- Only `ChannelDimension.FIRST` is supported. Added for compatibility with slow processors.
- input_data_format (`ChannelDimension` or `str`, *optional*, defaults to `self.input_data_format`):
- The channel dimension format for the input image. If unset, the channel dimension format is inferred
- from the input image. Can be one of:
- - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- - `"none"` or `ChannelDimension.NONE`: image in (height, width) format."""
-
-
-@add_start_docstrings(
- "Constructs a fast base image processor.",
- BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
-)
+@auto_docstring
class BaseImageProcessorFast(BaseImageProcessor):
resample = None
image_mean = None
@@ -310,7 +215,10 @@ def __init__(
if kwarg is not None:
setattr(self, key, kwarg)
else:
- setattr(self, key, getattr(self, key, None))
+ setattr(self, key, deepcopy(getattr(self, key, None)))
+
+ # get valid kwargs names
+ self._valid_kwargs_names = list(self.valid_kwargs.__annotations__.keys())
def resize(
self,
@@ -328,7 +236,7 @@ def resize(
Image to resize.
size (`SizeDict`):
Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
- resample (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`):
+ interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`):
`InterpolationMode` filter to use when resizing the image e.g. `InterpolationMode.BICUBIC`.
Returns:
@@ -367,6 +275,18 @@ def resize(
image = Image.fromarray(image)
return ms.tensor(np.array(resize(image))).permute(2, 0, 1)
+ @staticmethod
+ def compile_friendly_resize(
+ image: "ms.Tensor",
+ new_size: tuple[int, int],
+ interpolation: Optional["F.InterpolationMode"] = None,
+ antialias: bool = True,
+ ) -> "ms.Tensor":
+ """
+        Kept for API parity with the PyTorch fast image processors, where this wraps `F.resize` so that it is
+        compatible with torch.compile on uint8 tensors. Not implemented for MindSpore.
+ """
+ raise NotImplementedError("This method is not implemented for mindspore")
+
def rescale(
self,
image: "ms.Tensor",
@@ -519,6 +439,7 @@ def filter_out_unused_kwargs(self, kwargs: dict):
def _prepare_images_structure(
self,
images: ImageInput,
+ expected_ndims: int = 3,
) -> ImageInput:
"""
Prepare the images structure for processing.
@@ -530,7 +451,7 @@ def _prepare_images_structure(
Returns:
`ImageInput`: The images with a valid nesting.
"""
- return make_flat_list_of_images(images)
+ return make_flat_list_of_images(images, expected_ndims=expected_ndims)
def _process_image(
self,
@@ -552,6 +473,9 @@ def _process_image(
# not using F.to_tensor as it doesn't handle (C, H, W) numpy arrays
image = ms.from_numpy(image).contiguous()
+ # If the image is 2D, we need to unsqueeze it to add a channel dimension for processing
+ if image.ndim == 2:
+ image = image.unsqueeze(0)
# Infer the channel dimension format if not provided
if input_data_format is None:
input_data_format = infer_channel_dimension_format(image)
@@ -562,25 +486,44 @@ def _process_image(
return image
- def _prepare_input_images(
+ def _prepare_image_like_inputs(
self,
images: ImageInput,
- do_convert_rgb: bool = None,
+ do_convert_rgb: Optional[bool] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ expected_ndims: int = 3,
) -> list["ms.Tensor"]:
"""
- Prepare the input images for processing.
+ Prepare image-like inputs for processing.
+
+ Args:
+ images (`ImageInput`):
+ The image-like inputs to process.
+ do_convert_rgb (`bool`, *optional*):
+ Whether to convert the images to RGB.
+ input_data_format (`str` or `ChannelDimension`, *optional*):
+ The input data format of the images.
+ expected_ndims (`int`, *optional*):
+ The expected number of dimensions for the images. (can be 2 for segmentation maps etc.)
+
+ Returns:
+ List[`ms.Tensor`]: The processed images.
"""
- images = self._prepare_images_structure(images)
- process_image_fn = partial(
- self._process_image,
- do_convert_rgb=do_convert_rgb,
- input_data_format=input_data_format,
+
+ # Get structured images (potentially nested)
+ images = self._prepare_images_structure(images, expected_ndims=expected_ndims)
+
+ process_image_partial = partial(
+ self._process_image, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format
)
- # todo: yoni - check if we can parallelize this efficiently
- processed_images = []
- for image in images:
- processed_images.append(process_image_fn(image))
+
+ # Check if we have nested structure, assuming the nesting is consistent
+ has_nested_structure = len(images) > 0 and isinstance(images[0], (list, tuple))
+
+ if has_nested_structure:
+ processed_images = [[process_image_partial(img) for img in nested_list] for nested_list in images]
+ else:
+ processed_images = [process_image_partial(img) for img in images]
return processed_images
@@ -654,22 +597,21 @@ def _validate_preprocess_kwargs(
data_format=data_format,
)
- @add_start_docstrings(BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS)
- def preprocess(self, images: ImageInput, **kwargs: Unpack[DefaultFastImageProcessorKwargs]) -> BatchFeature:
- logger.warning("Please use FastImageProcessor cautiously. It may not have better inference performance!")
- validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self.valid_kwargs.__annotations__.keys())
+ def __call__(self, images: ImageInput, *args, **kwargs: Unpack[DefaultFastImageProcessorKwargs]) -> BatchFeature:
+ return self.preprocess(images, *args, **kwargs)
+
+ @auto_docstring
+ def preprocess(self, images: ImageInput, *args, **kwargs: Unpack[DefaultFastImageProcessorKwargs]) -> BatchFeature:
+ # args are not validated, but their order in the `preprocess` and `_preprocess` signatures must be the same
+ validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_kwargs_names)
# Set default kwargs from self. This ensures that if a kwarg is not provided
# by the user, it gets its default value from the instance, or is set to None.
- for kwarg_name in self.valid_kwargs.__annotations__:
+ for kwarg_name in self._valid_kwargs_names:
kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))
# Extract parameters that are only used for preparing the input images
do_convert_rgb = kwargs.pop("do_convert_rgb")
input_data_format = kwargs.pop("input_data_format")
- # Prepare input images
- images = self._prepare_input_images(
- images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format
- )
# Update kwargs that need further processing before being validated
kwargs = self._further_process_kwargs(**kwargs)
@@ -679,9 +621,13 @@ def preprocess(self, images: ImageInput, **kwargs: Unpack[DefaultFastImageProces
# torch resize uses interpolation instead of resample
resample = kwargs.pop("resample")
+
+ # Check if resample is an int before checking if it's an instance of PILImageResampling
+ # because if pillow < 9.1.0, resample is an int and PILImageResampling is a module.
+ # Checking PILImageResampling will fail with error `TypeError: isinstance() arg 2 must be a type or tuple of types`.
kwargs["interpolation"] = (
pil_mindspore_interpolation_mapping[resample]
- if isinstance(resample, (PILImageResampling, int))
+ if isinstance(resample, (int, PILImageResampling))
else resample
)
@@ -689,7 +635,28 @@ def preprocess(self, images: ImageInput, **kwargs: Unpack[DefaultFastImageProces
kwargs.pop("default_to_square")
kwargs.pop("data_format")
- return self._preprocess(images=images, **kwargs)
+ return self._preprocess_image_like_inputs(
+ images, *args, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, **kwargs
+ )
+
+ def _preprocess_image_like_inputs(
+ self,
+ images: ImageInput,
+ *args,
+ do_convert_rgb: bool,
+ input_data_format: ChannelDimension,
+ **kwargs: Unpack[DefaultFastImageProcessorKwargs],
+ ) -> BatchFeature:
+ """
+ Preprocess image-like inputs.
+        To be overridden by subclasses when image-like inputs other than images should be processed.
+ It can be used for segmentation maps, depth maps, etc.
+ """
+ # Prepare input images
+ images = self._prepare_image_like_inputs(
+ images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format
+ )
+ return self._preprocess(images, *args, **kwargs)
def _preprocess(
self,
@@ -741,6 +708,7 @@ def _preprocess(
def to_dict(self):
encoder_dict = super().to_dict()
encoder_dict.pop("_valid_processor_keys", None)
+ encoder_dict.pop("_valid_kwargs_names", None)
return encoder_dict
diff --git a/mindone/transformers/image_transforms.py b/mindone/transformers/image_transforms.py
index 8ff53d7f5e..8aa9c0aca1 100644
--- a/mindone/transformers/image_transforms.py
+++ b/mindone/transformers/image_transforms.py
@@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import warnings
+from collections import defaultdict
from collections.abc import Collection, Iterable
from math import ceil
from typing import Optional, Union
@@ -48,7 +48,9 @@ def to_channel_dimension_format(
input_channel_dim: Optional[Union[ChannelDimension, str]] = None,
) -> np.ndarray:
"""
- Converts `image` to the channel dimension format specified by `channel_dim`.
+    Converts `image` to the channel dimension format specified by `channel_dim`. The input can have an arbitrary
+    number of leading dimensions; only the last three dimensions are permuted to match the requested format.
Args:
image (`numpy.ndarray`):
@@ -72,9 +74,11 @@ def to_channel_dimension_format(
return image
if target_channel_dim == ChannelDimension.FIRST:
- image = image.transpose((2, 0, 1))
+ axes = list(range(image.ndim - 3)) + [image.ndim - 1, image.ndim - 3, image.ndim - 2]
+ image = image.transpose(axes)
elif target_channel_dim == ChannelDimension.LAST:
- image = image.transpose((1, 2, 0))
+ axes = list(range(image.ndim - 3)) + [image.ndim - 2, image.ndim - 1, image.ndim - 3]
+ image = image.transpose(axes)
else:
raise ValueError(f"Unsupported channel dimension format: {channel_dim}")
@@ -399,7 +403,7 @@ def normalize(
The channel dimension format of the input image. If unset, will use the inferred format from the input.
"""
if not isinstance(image, np.ndarray):
- raise ValueError("image must be a numpy array")
+ raise TypeError("image must be a numpy array")
if input_data_format is None:
input_data_format = infer_channel_dimension_format(image)
@@ -440,7 +444,6 @@ def center_crop(
size: tuple[int, int],
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
- return_numpy: Optional[bool] = None,
) -> np.ndarray:
"""
Crops the `image` to the specified `size` using a center crop. Note that if the image is too small to be cropped to
@@ -461,22 +464,11 @@ def center_crop(
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
If unset, will use the inferred format of the input image.
- return_numpy (`bool`, *optional*):
- Whether or not to return the cropped image as a numpy array. Used for backwards compatibility with the
- previous ImageFeatureExtractionMixin method.
- - Unset: will return the same type as the input image.
- - `True`: will return a numpy array.
- - `False`: will return a `PIL.Image.Image` object.
Returns:
`np.ndarray`: The cropped image.
"""
requires_backends(center_crop, ["vision"])
- if return_numpy is not None:
- warnings.warn("return_numpy is deprecated and will be removed in v.4.33", FutureWarning)
-
- return_numpy = True if return_numpy is None else return_numpy
-
if not isinstance(image, np.ndarray):
raise TypeError(f"Input image must be of type np.ndarray, got {type(image)}")
@@ -528,9 +520,6 @@ def center_crop(
new_image = new_image[..., max(0, top) : min(new_height, bottom), max(0, left) : min(new_width, right)]
new_image = to_channel_dimension_format(new_image, output_data_format, ChannelDimension.FIRST)
- if not return_numpy:
- new_image = to_pil_image(new_image)
-
return new_image
@@ -726,7 +715,7 @@ def _expand_for_data_format(values):
values = ((0, 0), *values) if input_data_format == ChannelDimension.FIRST else (*values, (0, 0))
# Add additional padding if there's a batch dimension
- values = (0, *values) if image.ndim == 4 else values
+ values = ((0, 0), *values) if image.ndim == 4 else values
return values
padding = _expand_for_data_format(padding)
@@ -812,37 +801,114 @@ def _cast_tensor_to_float(x):
return x.float()
+def _group_images_by_shape(nested_images, is_nested: bool = False):
+ """Helper function to flatten a single level of nested image structures and group by shape."""
+ grouped_images = defaultdict(list)
+ grouped_images_index = {}
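+    # e.g. images of shapes (3, 224, 224), (3, 224, 224) and (3, 336, 336) produce two groups keyed by
+    # (224, 224) and (336, 336), since the leading (channel) dimension is dropped from the key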
+ nested_images = [nested_images] if not is_nested else nested_images
+ for i, sublist in enumerate(nested_images):
+ for j, image in enumerate(sublist):
+ key = (i, j) if is_nested else j
+ shape = image.shape[1:]
+ grouped_images[shape].append(image)
+ grouped_images_index[key] = (shape, len(grouped_images[shape]) - 1)
+
+ return grouped_images, grouped_images_index
+
+
+def _reconstruct_nested_structure(indices, processed_images):
+ """Helper function to reconstruct a single level nested structure."""
+ # Find the maximum outer index
+ max_outer_idx = max(idx[0] for idx in indices.keys())
+
+ # Create the outer list
+ result = [None] * (max_outer_idx + 1)
+
+ # Group indices by outer index
+ nested_indices = defaultdict(list)
+ for i, j in indices.keys():
+ nested_indices[i].append(j)
+
+ for i in range(max_outer_idx + 1):
+ if i in nested_indices:
+ inner_max_idx = max(nested_indices[i])
+ inner_list = [None] * (inner_max_idx + 1)
+ for j in range(inner_max_idx + 1):
+ if (i, j) in indices:
+ shape, idx = indices[(i, j)]
+ inner_list[j] = processed_images[shape][idx]
+ result[i] = inner_list
+
+ return result
+
+
def group_images_by_shape(
- images: list["ms.Tensor"],
-) -> tuple[dict[tuple[int, int], list["ms.Tensor"]], dict[int, tuple[tuple[int, int], int]]]:
+ images: Union[list["ms.Tensor"], "ms.Tensor"],
+ is_nested: bool = False,
+) -> tuple[dict[tuple[int, int], list["ms.Tensor"]], dict[Union[int, tuple[int, int]], tuple[tuple[int, int], int]]]:
"""
Groups images by shape.
Returns a dictionary with the shape as key and a list of images with that shape as value,
and a dictionary with the index of the image in the original list as key and the shape and index in the grouped list as value.
+
+ The function supports both flat lists of tensors and nested structures.
+ The input must be either all flat or all nested, not a mix of both.
+
+ Args:
+ images (Union[list["ms.Tensor"], "ms.Tensor"]):
+            A list of images or a single tensor.
+ is_nested (bool, *optional*, defaults to False):
+ Whether the images are nested.
+
+ Returns:
+ tuple[dict[tuple[int, int], list["ms.Tensor"]], dict[Union[int, tuple[int, int]], tuple[tuple[int, int], int]]]:
+ - A dictionary with shape as key and list of images with that shape as value
+ - A dictionary mapping original indices to (shape, index) tuples
"""
- grouped_images = {}
- grouped_images_index = {}
- for i, image in enumerate(images):
- shape = image.shape[1:]
- if shape not in grouped_images:
- grouped_images[shape] = []
- grouped_images[shape].append(image)
- grouped_images_index[i] = (shape, len(grouped_images[shape]) - 1)
- # stack images with the same shape
- grouped_images = {shape: mint.stack(images, dim=0) for shape, images in grouped_images.items()}
+    # TODO: mindone.transformers does not support the `disable_grouping` option yet
+
+ # Handle single level nested structure
+ grouped_images, grouped_images_index = _group_images_by_shape(images, is_nested)
+
+ # Stack images with the same shape
+ grouped_images = {shape: mint.stack(images_list, dim=0) for shape, images_list in grouped_images.items()}
+
return grouped_images, grouped_images_index
def reorder_images(
- processed_images: dict[tuple[int, int], "ms.Tensor"], grouped_images_index: dict[int, tuple[int, int]]
-) -> list["ms.Tensor"]:
+ processed_images: dict[tuple[int, int], "ms.Tensor"],
+ grouped_images_index: dict[Union[int, tuple[int, int]], tuple[tuple[int, int], int]],
+ is_nested: bool = False,
+) -> Union[list["ms.Tensor"], "ms.Tensor"]:
"""
- Reconstructs a list of images in the original order.
+ Reconstructs images in the original order, preserving the original structure (nested or not).
+ The input structure is either all flat or all nested.
+
+ Args:
+ processed_images (dict[tuple[int, int], "ms.Tensor"]):
+ Dictionary mapping shapes to batched processed images.
+ grouped_images_index (dict[Union[int, tuple[int, int]], tuple[tuple[int, int], int]]):
+ Dictionary mapping original indices to (shape, index) tuples.
+ is_nested (bool, *optional*, defaults to False):
+            Whether the images are nested. This cannot be inferred from the input, since some processing functions
+            output nested images even for non-nested inputs, e.g. functions that split images into patches.
+
+
+ Returns:
+ Union[list["ms.Tensor"], "ms.Tensor"]:
+ Images in the original structure.
"""
- return [
- processed_images[grouped_images_index[i][0]][grouped_images_index[i][1]]
- for i in range(len(grouped_images_index))
- ]
+ if not is_nested:
+ return [
+ processed_images[grouped_images_index[i][0]][grouped_images_index[i][1]]
+ for i in range(len(grouped_images_index))
+ ]
+
+ return _reconstruct_nested_structure(grouped_images_index, processed_images)
class NumpyToTensor:
diff --git a/mindone/transformers/image_utils.py b/mindone/transformers/image_utils.py
index 6d8835952d..c099fa9242 100644
--- a/mindone/transformers/image_utils.py
+++ b/mindone/transformers/image_utils.py
@@ -18,24 +18,16 @@
import base64
import os
from collections.abc import Iterable
-from contextlib import redirect_stdout
from dataclasses import dataclass
from io import BytesIO
-from typing import TYPE_CHECKING, Callable, Optional, Union
+from typing import TYPE_CHECKING, Optional, Union
import numpy as np
import requests
from packaging import version
-from transformers import is_av_available
-from transformers.utils import is_cv2_available, is_decord_available, is_yt_dlp_available, logging
-from transformers.utils.constants import ( # noqa: F401
- IMAGENET_DEFAULT_MEAN,
- IMAGENET_DEFAULT_STD,
- IMAGENET_STANDARD_MEAN,
- IMAGENET_STANDARD_STD,
- OPENAI_CLIP_MEAN,
- OPENAI_CLIP_STD,
-)
+from transformers.utils import logging
+
+from mindspore import mint
from .utils import (
ExplicitEnum,
@@ -46,6 +38,14 @@
requires_backends,
to_numpy,
)
+from .utils.constants import ( # noqa: F401
+ IMAGENET_DEFAULT_MEAN,
+ IMAGENET_DEFAULT_STD,
+ IMAGENET_STANDARD_MEAN,
+ IMAGENET_STANDARD_STD,
+ OPENAI_CLIP_MEAN,
+ OPENAI_CLIP_STD,
+)
if is_vision_available():
import PIL.Image
@@ -85,18 +85,6 @@
] # noqa
-VideoInput = Union[
- list["PIL.Image.Image"],
- "np.ndarray",
- "mindspore.Tensor",
- list["np.ndarray"],
- list["mindspore.Tensor"],
- list[list["PIL.Image.Image"]],
- list[list["np.ndarrray"]],
- list[list["mindspore.Tensor"]],
-] # noqa
-
-
class ChannelDimension(ExplicitEnum):
FIRST = "channels_first"
LAST = "channels_last"
@@ -112,14 +100,6 @@ class AnnotionFormat(ExplicitEnum):
COCO_PANOPTIC = AnnotationFormat.COCO_PANOPTIC.value
-@dataclass
-class VideoMetadata:
- total_num_frames: int
- fps: float
- duration: float
- video_backend: str
-
-
AnnotationType = dict[str, Union[int, str, list[dict]]]
@@ -153,6 +133,15 @@ def is_valid_list_of_images(images: list):
return images and all(is_valid_image(image) for image in images)
+def concatenate_list(input_list):
+ if isinstance(input_list[0], list):
+ return [item for sublist in input_list for item in sublist]
+ elif isinstance(input_list[0], np.ndarray):
+ return np.concatenate(input_list, axis=0)
+ elif isinstance(input_list[0], mint.Tensor):
+ return mint.cat(input_list, dim=0)
+
+
def valid_images(imgs):
# If we have an list of images, make sure every image is valid
if isinstance(imgs, (list, tuple)):
@@ -223,13 +212,16 @@ def make_list_of_images(images, expected_ndims: int = 3) -> list[ImageInput]:
def make_flat_list_of_images(
images: Union[list[ImageInput], ImageInput],
+ expected_ndims: int = 3,
) -> ImageInput:
"""
Ensure that the output is a flat list of images. If the input is a single image, it is converted to a list of length 1.
If the input is a nested list of images, it is converted to a flat list of images.
Args:
- images (`Union[List[ImageInput], ImageInput]`):
+ images (`Union[list[ImageInput], ImageInput]`):
The input image.
+ expected_ndims (`int`, *optional*, defaults to 3):
+ The expected number of dimensions for a single input image.
Returns:
list: A list of images or a 4d array of images.
"""
@@ -242,15 +234,15 @@ def make_flat_list_of_images(
return [img for img_list in images for img in img_list]
if isinstance(images, (list, tuple)) and is_valid_list_of_images(images):
- if is_pil_image(images[0]) or images[0].ndim == 3:
+ if is_pil_image(images[0]) or images[0].ndim == expected_ndims:
return images
- if images[0].ndim == 4:
+ if images[0].ndim == expected_ndims + 1:
return [img for img_list in images for img in img_list]
if is_valid_image(images):
- if is_pil_image(images) or images.ndim == 3:
+ if is_pil_image(images) or images.ndim == expected_ndims:
return [images]
- if images.ndim == 4:
+ if images.ndim == expected_ndims + 1:
return list(images)
raise ValueError(f"Could not make a flat list of images from {images}")
@@ -258,12 +250,15 @@ def make_flat_list_of_images(
def make_nested_list_of_images(
images: Union[list[ImageInput], ImageInput],
+ expected_ndims: int = 3,
) -> ImageInput:
"""
Ensure that the output is a nested list of images.
Args:
- images (`Union[List[ImageInput], ImageInput]`):
+ images (`Union[list[ImageInput], ImageInput]`):
The input image.
+ expected_ndims (`int`, *optional*, defaults to 3):
+ The expected number of dimensions for a single input image.
Returns:
list: A list of list of images or a list of 4d array of images.
"""
@@ -277,52 +272,21 @@ def make_nested_list_of_images(
# If it's a list of images, it's a single batch, so convert it to a list of lists
if isinstance(images, (list, tuple)) and is_valid_list_of_images(images):
- if is_pil_image(images[0]) or images[0].ndim == 3:
+ if is_pil_image(images[0]) or images[0].ndim == expected_ndims:
return [images]
- if images[0].ndim == 4:
+ if images[0].ndim == expected_ndims + 1:
return [list(image) for image in images]
# If it's a single image, convert it to a list of lists
if is_valid_image(images):
- if is_pil_image(images) or images.ndim == 3:
+ if is_pil_image(images) or images.ndim == expected_ndims:
return [[images]]
- if images.ndim == 4:
+ if images.ndim == expected_ndims + 1:
return [list(images)]
raise ValueError("Invalid input type. Must be a single image, a list of images, or a list of batches of images.")
-def make_batched_videos(videos) -> VideoInput:
- """
- Ensure that the input is a list of videos.
- Args:
- videos (`VideoInput`):
- Video or videos to turn into a list of videos.
- Returns:
- list: A list of videos.
- """
- if isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]):
- # case 1: nested batch of videos so we flatten it
- if not is_pil_image(videos[0][0]) and videos[0][0].ndim == 4:
- videos = [[video for batch_list in batched_videos for video in batch_list] for batched_videos in videos]
- # case 2: list of videos represented as list of video frames
- return videos
-
- elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]):
- if is_pil_image(videos[0]) or videos[0].ndim == 3:
- return [videos]
- elif videos[0].ndim == 4:
- return [list(video) for video in videos]
-
- elif is_valid_image(videos):
- if is_pil_image(videos) or videos.ndim == 3:
- return [[videos]]
- elif videos.ndim == 4:
- return [list(videos)]
-
- raise ValueError(f"Could not make batched video from {videos}")
-
-
def to_numpy_array(img) -> np.ndarray:
if not is_valid_image(img):
raise ValueError(f"Invalid image type: {type(img)}")
@@ -374,12 +338,17 @@ def infer_channel_dimension_format(
first_dim, last_dim = 0, 2
elif image.ndim == 4:
first_dim, last_dim = 1, 3
+ elif image.ndim == 5:
+ first_dim, last_dim = 2, 4
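+        # e.g. a 5D batch of videos shaped (batch, frames, channels, height, width) carries its
+        # channel dimension at index 2 (channels-first) or index 4 (channels-last)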
else:
raise ValueError(f"Unsupported number of image dimensions: {image.ndim}")
if image.shape[first_dim] in num_channels and image.shape[last_dim] in num_channels:
logger.warning(
- f"The channel dimension is ambiguous. Got image shape {image.shape}. Assuming channels are the first dimension."
+ f"The channel dimension is ambiguous. Got image shape {image.shape}. Assuming channels are the first dimension. Use the\
+ [input_data_format]"
+ f"(https://huggingface.co/docs/transformers/main/internal/image_processing_utils#transformers.image_transforms.rescale.input_data_format) \
+ parameter to assign the channel dimension."
)
return ChannelDimension.FIRST
elif image.shape[first_dim] in num_channels:
@@ -554,347 +523,6 @@ def load_image(image: Union[str, "PIL.Image.Image"], timeout: Optional[float] =
return image
-def default_sample_indices_fn(metadata: VideoMetadata, num_frames=None, fps=None, **kwargs):
- """
- A default sampling function that replicates the logic used in get_uniform_frame_indices,
- while optionally handling `fps` if `num_frames` is not provided.
-
- Args:
- metadata (`VideoMetadata`):
- `VideoMetadata` object containing metadata about the video, such as "total_num_frames" or "fps".
- num_frames (`int`, *optional*):
- Number of frames to sample uniformly.
- fps (`int`, *optional*):
- Desired frames per second. Takes priority over num_frames if both are provided.
-
- Returns:
- `np.ndarray`: Array of frame indices to sample.
- """
- total_num_frames = metadata.total_num_frames
- video_fps = metadata.fps
-
- # If num_frames is not given but fps is, calculate num_frames from fps
- if num_frames is None and fps is not None:
- num_frames = int(total_num_frames / video_fps * fps)
- if num_frames > total_num_frames:
- raise ValueError(
- f"When loading the video with fps={fps}, we computed num_frames={num_frames} "
- f"which exceeds total_num_frames={total_num_frames}. Check fps or video metadata."
- )
-
- if num_frames is not None:
- indices = np.arange(0, total_num_frames, total_num_frames / num_frames, dtype=int)
- else:
- indices = np.arange(0, total_num_frames, dtype=int)
- return indices
-
-
-def read_video_opencv(
- video_path: str,
- sample_indices_fn: Callable,
- **kwargs,
-):
- """
- Decode a video using the OpenCV backend.
-
- Args:
- video_path (`str`):
- Path to the video file.
- sample_indices_fn (`Callable`):
- A callable function that will return indices at which the video should be sampled. If the video has to be loaded using
- by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
- If not provided, simple uniform sampling with fps is performed.
- Example:
- def sample_indices_fn(metadata, **kwargs):
- return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)
-
- Returns:
- Tuple[`np.array`, `VideoMetadata`]: A tuple containing:
- - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
- - `VideoMetadata` object.
- """
- # Lazy import cv2
- requires_backends(read_video_opencv, ["cv2"])
- import cv2
-
- video = cv2.VideoCapture(video_path)
- total_num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
- video_fps = video.get(cv2.CAP_PROP_FPS)
- duration = total_num_frames / video_fps if video_fps else 0
- metadata = VideoMetadata(
- total_num_frames=int(total_num_frames), fps=float(video_fps), duration=float(duration), video_backend="opencv"
- )
- indices = sample_indices_fn(metadata=metadata, **kwargs)
-
- index = 0
- frames = []
- while video.isOpened():
- success, frame = video.read()
- if not success:
- break
- if index in indices:
- height, width, channel = frame.shape
- frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
- frames.append(frame[0:height, 0:width, 0:channel])
- if success:
- index += 1
- if index >= total_num_frames:
- break
-
- video.release()
- metadata.frames_indices = indices
- return np.stack(frames), metadata
-
-
-def read_video_decord(
- video_path: str,
- sample_indices_fn: Optional[Callable] = None,
- **kwargs,
-):
- """
- Decode a video using the Decord backend.
-
- Args:
- video_path (`str`):
- Path to the video file.
- sample_indices_fn (`Callable`, *optional*):
- A callable function that will return indices at which the video should be sampled. If the video has to be loaded using
- by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
- If not provided, simple uniform sampling with fps is performed.
- Example:
- def sample_indices_fn(metadata, **kwargs):
- return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)
-
- Returns:
- Tuple[`np.array`, `VideoMetadata`]: A tuple containing:
- - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
- - `VideoMetadata` object.
- """
- # Lazy import from decord
- requires_backends(read_video_decord, ["decord"])
- from decord import VideoReader, cpu
-
- vr = VideoReader(uri=video_path, ctx=cpu(0)) # decord has problems with gpu
- video_fps = vr.get_avg_fps()
- total_num_frames = len(vr)
- duration = total_num_frames / video_fps if video_fps else 0
- metadata = VideoMetadata(
- total_num_frames=int(total_num_frames), fps=float(video_fps), duration=float(duration), video_backend="decord"
- )
-
- indices = sample_indices_fn(metadata=metadata, **kwargs)
-
- frames = vr.get_batch(indices).asnumpy()
- metadata.frames_indices = indices
- return frames, metadata
-
-
-def read_video_pyav(
- video_path: str,
- sample_indices_fn: Callable,
- **kwargs,
-):
- """
- Decode the video with PyAV decoder.
-
- Args:
- video_path (`str`):
- Path to the video file.
- sample_indices_fn (`Callable`, *optional*):
- A callable function that will return indices at which the video should be sampled. If the video has to be loaded using
- by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
- If not provided, simple uniform sampling with fps is performed.
- Example:
- def sample_indices_fn(metadata, **kwargs):
- return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)
-
- Returns:
- Tuple[`np.array`, `VideoMetadata`]: A tuple containing:
- - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
- - `VideoMetadata` object.
- """
- # Lazy import av
- requires_backends(read_video_pyav, ["av"])
- import av
-
- container = av.open(video_path)
- total_num_frames = container.streams.video[0].frames
- video_fps = container.streams.video[0].average_rate # should we better use `av_guess_frame_rate`?
- duration = total_num_frames / video_fps if video_fps else 0
- metadata = VideoMetadata(
- total_num_frames=int(total_num_frames), fps=float(video_fps), duration=float(duration), video_backend="pyav"
- )
- indices = sample_indices_fn(metadata=metadata, **kwargs)
-
- frames = []
- container.seek(0)
- end_index = indices[-1]
- for i, frame in enumerate(container.decode(video=0)):
- if i > end_index:
- break
- if i >= 0 and i in indices:
- frames.append(frame)
-
- video = np.stack([x.to_ndarray(format="rgb24") for x in frames])
- metadata.frames_indices = indices
- return video, metadata
-
-
-def read_video_mindspore(
- video_path: str,
- sample_indices_fn: Callable,
- **kwargs,
-):
- """
- Decode the video with mindspore.dataset decoder.
-
- Args:
- video_path (`str`):
- Path to the video file.
- sample_indices_fn (`Callable`, *optional*):
- A callable function that will return indices at which the video should be sampled. If the video has to be loaded using
- by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
- If not provided, simple uniform sampling with fps is performed.
- Example:
- def sample_indices_fn(metadata, **kwargs):
- return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)
-
- Returns:
- tuple[`np.array`, `VideoMetadata`]: A tuple containing:
- - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
- - `VideoMetadata` object.
- """
- video, _, info = ms.dataset.vision.read_video(
- video_path,
- start_pts=0.0,
- end_pts=None,
- pts_unit="sec",
- )
- video_fps = info["video_fps"]
- total_num_frames = video.size(0)
- duration = total_num_frames / video_fps if video_fps else 0
- metadata = VideoMetadata(
- total_num_frames=int(total_num_frames),
- fps=float(video_fps),
- duration=float(duration),
- video_backend="mindspore",
- )
-
- indices = sample_indices_fn(metadata=metadata, **kwargs)
-
- video = video[indices].contiguous().numpy()
- metadata.frames_indices = indices
- return video, metadata
-
-
-VIDEO_DECODERS = {
- "decord": read_video_decord,
- "opencv": read_video_opencv,
- "pyav": read_video_pyav,
- "torchvision": read_video_mindspore,
-}
-
-
-def load_video(
- video: Union[str, "VideoInput"],
- num_frames: Optional[int] = None,
- fps: Optional[int] = None,
- backend: str = "opencv",
- sample_indices_fn: Optional[Callable] = None,
- **kwargs,
-) -> np.array:
- """
- Loads `video` to a numpy array.
-
- Args:
- video (`str` or `VideoInput`):
- The video to convert to the numpy array format. Can be a link to video or local path.
- num_frames (`int`, *optional*):
- Number of frames to sample uniformly. If not passed, the whole video is loaded.
- fps (`int`, *optional*):
- Number of frames to sample per second. Should be passed only when `num_frames=None`.
- If not specified and `num_frames==None`, all frames are sampled.
- backend (`str`, *optional*, defaults to `"opencv"`):
- The backend to use when loading the video. Can be any of ["decord", "pyav", "opencv", "torchvision"]. Defaults to "opencv".
- sample_indices_fn (`Callable`, *optional*):
- A callable function that will return indices at which the video should be sampled. If the video has to be loaded using
- by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
- If not provided, simple uniformt sampling with fps is performed, otherwise `sample_indices_fn` has priority over other args.
- The function expects at input the all args along with all kwargs passed to `load_video` and should output valid
- indices at which the video should be sampled. For example:
-
- Example:
- def sample_indices_fn(metadata, **kwargs):
- return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)
-
- Returns:
- tuple[`np.array`, Dict]: A tuple containing:
- - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
- - Metadata dictionary.
- """
-
- # If `sample_indices_fn` is given, we can accept any args as those might be needed by custom `sample_indices_fn`
- if fps is not None and num_frames is not None and sample_indices_fn is None:
- raise ValueError(
- "`num_frames`, `fps`, and `sample_indices_fn` are mutually exclusive arguments, please use only one!"
- )
-
- # If user didn't pass a sampling function, create one on the fly with default logic
- if sample_indices_fn is None:
-
- def sample_indices_fn_func(metadata, **fn_kwargs):
- return default_sample_indices_fn(metadata, num_frames=num_frames, fps=fps, **fn_kwargs)
-
- sample_indices_fn = sample_indices_fn_func
-
- if video.startswith("https://www.youtube.com") or video.startswith("http://www.youtube.com"):
- if not is_yt_dlp_available():
- raise ImportError("To load a video from YouTube url you have to install `yt_dlp` first.")
- # Lazy import from yt_dlp
- requires_backends(load_video, ["yt_dlp"])
- from yt_dlp import YoutubeDL
-
- buffer = BytesIO()
- with redirect_stdout(buffer), YoutubeDL() as f:
- f.download([video])
- bytes_obj = buffer.getvalue()
- file_obj = BytesIO(bytes_obj)
- elif video.startswith("http://") or video.startswith("https://"):
- file_obj = BytesIO(requests.get(video).content)
- elif os.path.isfile(video):
- file_obj = video
- elif is_valid_image(video) or (isinstance(video, (list, tuple)) and is_valid_image(video[0])):
- file_obj = None
- else:
- raise TypeError("Incorrect format used for video. Should be an url linking to an video or a local path.")
-
- # can also load with decord, but not cv2/torchvision
- # both will fail in case of url links
- video_is_url = video.startswith("http://") or video.startswith("https://")
- if video_is_url and backend in ["opencv", "mindspore"]:
- raise ValueError(
- "If you are trying to load a video from URL, you can decode the video only with `pyav` or `decord` as backend"
- )
-
- if file_obj is None:
- return video
-
- if (
- (not is_decord_available() and backend == "decord")
- or (not is_av_available() and backend == "pyav")
- or (not is_cv2_available() and backend == "opencv")
- or (not is_mindspore_available() and backend == "torchvision")
- ):
- raise ImportError(
- f"You chose backend={backend} for loading the video but the required library is not found in your environment "
- f"Make sure to install {backend} before loading the video."
- )
-
- video_decoder = VIDEO_DECODERS[backend]
- video, metadata = video_decoder(file_obj, sample_indices_fn, **kwargs)
- return video, metadata
-
-
def load_images(
images: Union[list, tuple, str, "PIL.Image.Image"], timeout: Optional[float] = None
) -> Union["PIL.Image.Image", list["PIL.Image.Image"], list[list["PIL.Image.Image"]]]:
@@ -1147,7 +775,7 @@ def resize(self, image, size, resample=None, default_to_square=True, max_size=No
default_to_square (`bool`, *optional*, defaults to `True`):
How to convert `size` when it is a single int. If set to `True`, the `size` will be converted to a
square (`size`,`size`). If set to `False`, will replicate
             [`torchvision.transforms.Resize`](https://pytorch.org/vision/stable/transforms.html#torchvision.transforms.Resize)
with support for resizing only the smallest edge and providing an optional `max_size`.
max_size (`int`, *optional*, defaults to `None`):
The maximum allowed for the longer edge of the resized image: if the longer edge of the image is
@@ -1351,12 +979,12 @@ class SizeDict:
Hashable dictionary to store image size information.
"""
- height: int = None
- width: int = None
- longest_edge: int = None
- shortest_edge: int = None
- max_height: int = None
- max_width: int = None
+ height: Optional[int] = None
+ width: Optional[int] = None
+ longest_edge: Optional[int] = None
+ shortest_edge: Optional[int] = None
+ max_height: Optional[int] = None
+ max_width: Optional[int] = None
def __getitem__(self, key):
if hasattr(self, key):
diff --git a/mindone/transformers/masking_utils.py b/mindone/transformers/masking_utils.py
index 25fb7c64eb..e7329a447b 100644
--- a/mindone/transformers/masking_utils.py
+++ b/mindone/transformers/masking_utils.py
@@ -15,11 +15,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import itertools
from typing import Callable, Optional, Union
from transformers.configuration_utils import PretrainedConfig
import mindspore as ms
+import mindspore.mint.nn.functional as F
from mindspore import mint
from .cache_utils import Cache
@@ -76,6 +78,18 @@ def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
return inner_mask
+def chunked_overlay(chunk_size: int) -> Callable:
+ """
+    This is an overlay depicting a chunked attention pattern. Add it on top of a causal mask for a proper chunked
+ attention mask.
+ """
+
+ def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
+ return kv_idx // chunk_size == q_idx // chunk_size
+
+ return inner_mask
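+# Quick sanity check of the chunked overlay (sketch, chunk_size=3): positions 0-2 form the first
+# chunk and positions 3-5 the second, so chunked_overlay(3)(0, 0, q_idx=4, kv_idx=2) is False
+# (different chunks) while chunked_overlay(3)(0, 0, q_idx=4, kv_idx=3) is True (same chunk).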
+
+
def sliding_window_causal_mask_function(sliding_window: int) -> Callable:
"""
This return the mask_function function to create a sliding window mask.
@@ -83,6 +97,50 @@ def sliding_window_causal_mask_function(sliding_window: int) -> Callable:
return and_masks(sliding_window_overlay(sliding_window), causal_mask_function)
+def chunked_causal_mask_function(chunk_size: int) -> Callable:
+ """
+    This returns the mask function used to create a chunked attention mask.
+ """
+ return and_masks(chunked_overlay(chunk_size), causal_mask_function)
+
+
+def padding_mask_function(padding_mask: ms.Tensor) -> Callable:
+ """
+    This returns the mask function corresponding to a 2D padding mask.
+ """
+
+ def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
+ # Note that here the mask should ALWAYS be at least of the max `kv_index` size in the dimension 1. This is because
+ # we cannot pad it here in the mask_function as we don't know the final size, and we cannot try/except, as it is not
+ # vectorizable on accelerator devices
+ return padding_mask[batch_idx, kv_idx]
+
+ return inner_mask
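+# Sketch: with padding_mask = ms.Tensor([[True, True, False]]) (last key position padded),
+# padding_mask_function(padding_mask)(batch_idx=0, head_idx=0, q_idx=0, kv_idx=2) returns False,
+# i.e. no query of batch 0 is allowed to attend to the padded key at position 2.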
+
+
+def packed_sequence_mask_function(packed_sequence_mask: ms.Tensor) -> Callable:
+ """
+    This returns the mask function corresponding to a 2D packed sequence mask.
+ """
+
+ def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
+ return packed_sequence_mask[batch_idx, q_idx] == packed_sequence_mask[batch_idx, kv_idx]
+
+ return inner_mask
+
+
+def add_offsets_to_mask_function(mask_function: Callable, q_offset: int, kv_offset: int) -> Callable:
+ """
+ This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths,
+ not start and end indices.
+ """
+
+ def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
+ return mask_function(batch_idx, head_idx, q_idx + q_offset, kv_idx + kv_offset)
+
+ return inner_mask
+
+
def _vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) -> Callable:
"""
Used to vmap our mask_functions over the q_idx and kv_idx dimensions of the inputs. Optionally, vmap over
@@ -132,6 +190,138 @@ def prepare_padding_mask(
return local_padding_mask
+def sdpa_mask_recent_torch(
+ batch_size: int,
+ cache_position: ms.Tensor,
+ kv_length: int,
+ kv_offset: int = 0,
+ mask_function: Callable = causal_mask_function,
+ attention_mask: Optional[ms.Tensor] = None,
+ local_size: Optional[int] = None,
+ allow_is_causal_skip: bool = True,
+ **kwargs,
+) -> Optional[ms.Tensor]:
+ """
+ Create a 4D boolean mask of shape `(batch_size, 1, query_length, kv_length)` where a value of True indicates that
+ the element should take part in the attention computation, and False that it should not.
+    In the upstream PyTorch implementation this variant requires torch>=2.5, as the context manager it relies on is otherwise not available.
+
+ Args:
+ batch_size (`int`):
+ The batch size of the input sequence.
+ cache_position (`ms.Tensor`):
+ A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
+ kv_length (`int`):
+ The size that the key and value states will have during the attention computation.
+ kv_offset (`int`, optional):
+ An optional offset to indicate at which first position the key and values states will refer to.
+ mask_function (`Callable`):
+ The mask factory function describing the mask pattern.
+ attention_mask (`ms.Tensor`, optional):
+ The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length)
+ local_size (`int`, optional):
+ The size of the local attention, if we do not use full attention. This is used only if `allow_is_causal_skip=True`
+ to try to skip mask creation if possible.
+ allow_is_causal_skip (`bool`, optional):
+ Whether to allow to return `None` for the mask under conditions where we can use the `is_causal` argument in
+ `torch.sdpa` instead. Default to `True`.
+ allow_torch_fix (`bool`, optional):
+ Whether to update the mask in case a query is not attending to any tokens, to solve a bug in torch's older
+ versions. We need an arg to skip it when using eager. By default `True`.
+
+
+ ## Creating a simple causal mask:
+
+ To create the following causal mask:
+
+ 0 ■ ⬚ ⬚ ⬚ ⬚
+ 1 ■ ■ ⬚ ⬚ ⬚
+ 2 ■ ■ ■ ⬚ ⬚
+ 3 ■ ■ ■ ■ ⬚
+ 4 ■ ■ ■ ■ ■
+
+ You can do
+
+ ```python
+ >>> create_4d_causal_mask(batch_size=1, cache_position=mint.arange(5), kv_length=5)
+ >>> tensor([[[[ True, False, False, False, False],
+ [ True, True, False, False, False],
+ [ True, True, True, False, False],
+ [ True, True, True, True, False],
+ [ True, True, True, True, True]]]])
+ ```
+
+ ## Creating a sliding window mask:
+
+ To create the following sliding window mask (`sliding_window=3`):
+
+ 0 ■ ⬚ ⬚ ⬚ ⬚
+ 1 ■ ■ ⬚ ⬚ ⬚
+ 2 ■ ■ ■ ⬚ ⬚
+ 3 ⬚ ■ ■ ■ ⬚
+ 4 ⬚ ⬚ ■ ■ ■
+
+ You can do
+
+ ```python
+ >>> create_4d_causal_mask(batch_size=1, cache_position=mint.arange(5), kv_length=5, mask_function=sliding_window_causal_mask_function(3))
+ >>> tensor([[[[ True, False, False, False, False],
+ [ True, True, False, False, False],
+ [ True, True, True, False, False],
+ [False, True, True, True, False],
+ [False, False, True, True, True]]]])
+ ```
+
+ ## Creating a chunked attention mask
+
+ To create the following chunked attention mask (`chunk_size=3`):
+
+ 0 ■ ⬚ ⬚ ⬚ ⬚
+ 1 ■ ■ ⬚ ⬚ ⬚
+ 2 ■ ■ ■ ⬚ ⬚
+ 3 ⬚ ⬚ ⬚ ■ ⬚
+ 4 ⬚ ⬚ ⬚ ■ ■
+
+ You can do
+
+ ```python
+ >>> create_4d_causal_mask(batch_size=1, cache_position=mint.arange(5), kv_length=5, mask_function=chunked_causal_mask_function(3))
+ >>> tensor([[[[ True, False, False, False, False],
+ [ True, True, False, False, False],
+ [ True, True, True, False, False],
+ [False, False, False, True, False],
+ [False, False, False, True, True]]]])
+ ```
+
+ """
+ q_length = cache_position.shape[0]
+ # Potentially pad the 2D mask, and slice it correctly
+ padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset, _slice=False)
+
+ # Under specific conditions, we can avoid materializing the mask, instead relying on the `is_causal` argument
+ if allow_is_causal_skip and _ignore_causal_mask_sdpa(padding_mask, q_length, kv_length, kv_offset, local_size):
+ return None
+
+ # Similar to `kv_arange = mint.arange(start=kv_offset, end=kv_offset + kv_length, device=cache_position.device)`
+ # but without data-dependent slicing (i.e. torch.compile friendly)
+    kv_arange = mint.arange(kv_length)
+ kv_arange += kv_offset
+
+ # Potentially add the padding 2D mask
+ if padding_mask is not None:
+ mask_function = and_masks(mask_function, padding_mask_function(padding_mask))
+
+    batch_arange = mint.arange(batch_size)
+    head_arange = mint.arange(1)
+ # This creates the 4D mask easily. Note that we need this context manager as vmap cannot handle slicing a tensor from
+ # scalar tensor (it internally calls `.item()` which vmap does not allow, but this context works around it
+ # We don't need to add an offset to the mask_function either, as we vmap directly the correct indices for k and kv indices
+ # with TransformGetItemToIndex():
+ causal_mask = _vmap_for_bhqkv(mask_function)(batch_arange, head_arange, cache_position, kv_arange)
+
+ return causal_mask
+
+
def sdpa_mask_older_torch(
batch_size: int,
cache_position: ms.Tensor,
@@ -156,7 +346,7 @@ def sdpa_mask_older_torch(
Args:
batch_size (`int`):
The batch size of the input sequence.
- cache_position (`torch.Tensor`):
+ cache_position (`ms.Tensor`):
A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
kv_length (`int`):
The size that the key and value states will have during the attention computation.
@@ -164,7 +354,7 @@ def sdpa_mask_older_torch(
An optional offset to indicate at which first position the key and values states will refer to.
mask_function (`Callable`):
The mask factory function describing the mask pattern.
- attention_mask (`torch.Tensor`, optional):
+ attention_mask (`ms.Tensor`, optional):
The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length)
local_size (`int`, optional):
The size of the local attention, if we do not use full attention. This is used only if `allow_is_causal_skip=True`
@@ -184,7 +374,7 @@ def sdpa_mask_older_torch(
if allow_is_causal_skip and _ignore_causal_mask_sdpa(padding_mask, q_length, kv_length, kv_offset, local_size):
return None
- # Similar to `kv_arange = torch.arange(start=kv_offset, end=kv_offset + kv_length, device=cache_position.device)`
+ # Similar to `kv_arange = mint.arange(start=kv_offset, end=kv_offset + kv_length, device=cache_position.device)`
# but without data-dependent slicing (i.e. torch.compile friendly)
kv_arange = mint.arange(kv_length)
kv_arange += kv_offset
@@ -246,7 +436,7 @@ def _ignore_causal_mask_sdpa(
# We use the version with newer torch whenever possible, as it is more general and can handle arbitrary mask functions
# (especially mask_function indexing a tensor, such as the padding mask function)
-sdpa_mask = sdpa_mask_older_torch
+sdpa_mask = sdpa_mask_older_torch  # TODO: use sdpa_mask_recent_torch or sdpa_mask_older_torch?
def eager_mask(
@@ -267,7 +457,7 @@ def eager_mask(
Args:
batch_size (`int`):
The batch size of the input sequence.
- cache_position (`torch.Tensor`):
+ cache_position (`ms.Tensor`):
A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
kv_length (`int`):
The size that the key and value states will have during the attention computation.
@@ -275,7 +465,7 @@ def eager_mask(
An optional offset to indicate at which first position the key and values states will refer to.
mask_function (`Callable`):
The mask factory function describing the mask pattern.
- attention_mask (`torch.Tensor`, optional):
+ attention_mask (`ms.Tensor`, optional):
The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length)
dtype (`torch.dtype`, optional):
The dtype to use for the mask. By default, `torch.float32`.
@@ -365,12 +555,40 @@ class AttentionMaskInterface(GeneralInterface):
ALL_MASK_ATTENTION_FUNCTIONS: AttentionMaskInterface = AttentionMaskInterface()
+def find_packed_sequence_indices(position_ids: ms.Tensor) -> ms.Tensor:
+ """
+ Find the indices of the sequence to which each new query token in the sequence belongs when using packed
+ tensor format (i.e. several sequences packed in the same batch dimension).
+
+ Args:
+ position_ids (`ms.Tensor`)
+ A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences.
+
+ Returns:
+ A 2D tensor where each similar integer indicates that the tokens belong to the same sequence. For example, if we
+ pack 3 sequences of 2, 3 and 1 tokens respectively along a single batch dim, this will return [[0, 0, 1, 1, 1, 2]].
+ """
+    # What separates different sequences is when 2 consecutive position_ids are separated by more than 1. So
+ # taking the diff (by prepending the first value - 1 to keep correct indexing) and applying cumsum to the result
+ # gives exactly the sequence indices
+ # Note that we assume that a single sequence cannot span several batch dimensions, i.e. 1 single sequence
+ # cannot be part of the end of the first batch dim and the start of the 2nd one for example
+ first_dummy_value = position_ids[:, :1] - 1 # We just need the diff on this first value to be 1
+ position_diff = mint.diff(position_ids, prepend=first_dummy_value, dim=-1)
+ packed_sequence_mask = (position_diff != 1).cumsum(-1)
+
+ # Here it would be nice to return None if we did not detect packed sequence format, i.e. if `packed_sequence_mask[:, -1] == 0`
+ # but it causes issues with export
+ return packed_sequence_mask
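+# Worked example (mirrors the docstring): packing sequences of lengths 2, 3 and 1 gives
+#   position_ids                = [[0, 1, 0, 1, 2, 0]]
+#   position_diff               = [[1, 1, -1, 1, 1, -2]]   (diff after prepending position_ids[:, :1] - 1)
+#   (position_diff != 1).cumsum = [[0, 0, 1, 1, 1, 2]]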
+
+
def _preprocess_mask_arguments(
config: PretrainedConfig,
input_embeds: ms.Tensor,
attention_mask: Optional[Union[ms.Tensor, BlockMask]],
cache_position: ms.Tensor,
past_key_values: Optional[Cache],
+ position_ids: Optional[ms.Tensor],
layer_idx: Optional[int],
) -> tuple[bool, Optional[Union[ms.Tensor, BlockMask]], int, int]:
"""
@@ -390,6 +608,8 @@ def _preprocess_mask_arguments(
A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
past_key_values (`Cache`, optional):
The past key values, if we use a cache.
+ position_ids (`ms.Tensor`, optional)
+ A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences.
layer_idx (`int`, optional):
If `past_key_values` is not None, this is the layer index of the cache from which to get the key-value
length and offset. Indeed, for hybrid caches, different layers may return different lengths.
@@ -399,6 +619,9 @@ def _preprocess_mask_arguments(
Whether we should early exit mask creation, and return the mask as-is.
attention_mask (`ms.Tensor` or `BlockMask` or `None`):
The attention mask to either return immediately, or to use in downstream mask creation.
+ packed_sequence_mask (`ms.Tensor`, optional):
+ In case we detected packed sequence format, this is a tensor where each similar integer indicates that
+ the tokens belong to the same sequence.
kv_length (`int`):
The size that the key and value states will have during the attention computation.
kv_offset (`int`):
@@ -414,7 +637,7 @@ def _preprocess_mask_arguments(
# with `torch._dynamo.exc.Unsupported: 'inline in skipfiles:Mapping.__contains__ | __contains__, skipped
# according trace_rules.lookup SKIP_DIRS'` -- can be removed when we require Python>=3.11
if config._attn_implementation not in ALL_MASK_ATTENTION_FUNCTIONS._global_mapping:
- return True, None, None, None
+ return True, None, None, None, None
# Move the mask to correct device, and potentially switch dtype for efficiency
if attention_mask is not None and attention_mask.ndim == 2:
@@ -426,8 +649,17 @@ def _preprocess_mask_arguments(
# Otherwise, the sizes are simply the input sizes
else:
kv_length, kv_offset = input_embeds.shape[1], 0
+ # We check the position_ids for potential packed sequence format (only if the 2D attention mask is explicitly None,
+ # and we don't have past_key_values, i.e. generally a training setup)
+ packed_sequence_mask = None
+ if position_ids is not None and attention_mask is None and past_key_values is None:
+ batch_size = input_embeds.shape[0]
+ # The position ids are sometimes just unsqueezed, without being expanded
+ if batch_size != position_ids.shape[0]:
+ position_ids = position_ids.expand(batch_size, -1)
+ packed_sequence_mask = find_packed_sequence_indices(position_ids)
- return False, attention_mask, kv_length, kv_offset
+ return False, attention_mask, packed_sequence_mask, kv_length, kv_offset
def create_causal_mask(
@@ -436,6 +668,7 @@ def create_causal_mask(
attention_mask: Optional[ms.Tensor],
cache_position: ms.Tensor,
past_key_values: Optional[Cache],
+ position_ids: Optional[ms.Tensor] = None,
or_mask_function: Optional[Callable] = None,
and_mask_function: Optional[Callable] = None,
) -> Optional[Union[ms.Tensor, BlockMask]]:
@@ -457,6 +690,8 @@ def create_causal_mask(
A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
past_key_values (`Cache`, optional):
The past key values, if we use a cache.
+ position_ids (`ms.Tensor`, optional)
+ A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences.
or_mask_function (`Callable`, optional):
An optional mask function to combine with the causal mask function (by doing the union of both). This is
useful to easily overlay another mask on top of the causal one, for example for image tokens handling.
@@ -470,8 +705,8 @@ def create_causal_mask(
else:
layer_idx = 0
- early_exit, attention_mask, kv_length, kv_offset = _preprocess_mask_arguments(
- config, input_embeds, attention_mask, cache_position, past_key_values, layer_idx
+ early_exit, attention_mask, packed_sequence_mask, kv_length, kv_offset = _preprocess_mask_arguments(
+ config, input_embeds, attention_mask, cache_position, past_key_values, position_ids, layer_idx
)
if early_exit:
return attention_mask
@@ -484,6 +719,10 @@ def create_causal_mask(
# TODO: cyril -> probably revisit and remove this, but a lot of tests rely on it
allow_is_causal_skip = not past_key_values.is_compileable if past_key_values is not None else True
+ # If we detected packing format
+ if packed_sequence_mask is not None:
+ mask_factory_function = and_masks(mask_factory_function, packed_sequence_mask_function(packed_sequence_mask))
+ allow_is_causal_skip = False
# Allow slight deviations from causal mask
if or_mask_function is not None:
mask_factory_function = or_masks(mask_factory_function, or_mask_function)
@@ -513,6 +752,7 @@ def create_sliding_window_causal_mask(
attention_mask: Optional[ms.Tensor],
cache_position: ms.Tensor,
past_key_values: Optional[Cache],
+ position_ids: Optional[ms.Tensor] = None,
or_mask_function: Optional[Callable] = None,
and_mask_function: Optional[Callable] = None,
) -> Optional[Union[ms.Tensor, BlockMask]]:
@@ -535,6 +775,8 @@ def create_sliding_window_causal_mask(
A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
past_key_values (`Cache`, optional):
The past key values, if we use a cache.
+ position_ids (`ms.Tensor`, optional)
+ A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences.
or_mask_function (`Callable`, optional):
An optional mask function to combine with the sliding causal mask function (by doing the union of both). This is
useful to easily overlay another mask on top of the sliding causal one, for example for image tokens handling.
@@ -548,8 +790,8 @@ def create_sliding_window_causal_mask(
else:
layer_idx = 0
- early_exit, attention_mask, kv_length, kv_offset = _preprocess_mask_arguments(
- config, input_embeds, attention_mask, cache_position, past_key_values, layer_idx
+ early_exit, attention_mask, packed_sequence_mask, kv_length, kv_offset = _preprocess_mask_arguments(
+ config, input_embeds, attention_mask, cache_position, past_key_values, position_ids, layer_idx
)
if early_exit:
return attention_mask
@@ -565,10 +807,17 @@ def create_sliding_window_causal_mask(
# Do not allow skip if we are compiling (this is to match BC)
# TODO: cyril -> probably revisit and remove this, but a lot of tests rely on it
allow_is_causal_skip = not past_key_values.is_compileable if past_key_values is not None else True
-
+ # If we detected packing format
+ if packed_sequence_mask is not None:
+ mask_factory_function = and_masks(mask_factory_function, packed_sequence_mask_function(packed_sequence_mask))
+ allow_is_causal_skip = False
# Allow slight deviations from sliding causal mask
- if or_mask_function is not None or and_mask_function is not None:
- raise NotImplementedError("`or_mask_function` or `and_mask_function` arguments are not supported yet.")
+ if or_mask_function is not None:
+ mask_factory_function = or_masks(mask_factory_function, or_mask_function)
+ allow_is_causal_skip = False
+ if and_mask_function is not None:
+ mask_factory_function = and_masks(mask_factory_function, and_mask_function)
+ allow_is_causal_skip = False
# We now create the mask
causal_mask = mask_interface(
@@ -586,7 +835,300 @@ def create_sliding_window_causal_mask(
return causal_mask
+def create_chunked_causal_mask(
+ config: PretrainedConfig,
+ input_embeds: ms.Tensor,
+ attention_mask: Optional[ms.Tensor],
+ cache_position: ms.Tensor,
+ past_key_values: Optional[Cache],
+ position_ids: Optional[ms.Tensor] = None,
+ or_mask_function: Optional[Callable] = None,
+ and_mask_function: Optional[Callable] = None,
+) -> Optional[Union[ms.Tensor, BlockMask]]:
+ """
+ Create a chunked attention causal mask based on the attention implementation used (stored in the config). This type
+    of attention pattern was mostly democratized by Llama4. If `past_key_values` has a HybridCache structure, this
+ function will return the mask corresponding to one of the "chunked_attention" layers (to align to what is needed in the
+ `modeling_xxx.py` files).
+
+ Args:
+ config (`PretrainedConfig`):
+ The model config.
+ input_embeds (`ms.Tensor`):
+ The input embeddings of shape (batch_size, query_length, hidden_dim). This is used only to infer the
+ batch size, query length and dtype.
+ attention_mask (`ms.Tensor`, optional):
+ The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length).
+ It can also be an already prepared 4D mask, in which case it is returned as-is.
+ cache_position (`ms.Tensor`):
+ A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
+ past_key_values (`Cache`, optional):
+ The past key values, if we use a cache.
+ position_ids (`ms.Tensor`, optional)
+ A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences.
+ or_mask_function (`Callable`, optional):
+ An optional mask function to combine with the chunked causal mask function (by doing the union of both). This is
+ useful to easily overlay another mask on top of the chunked causal one, for example for image tokens handling.
+ and_mask_function (`Callable`, optional):
+ An optional mask function to combine with the chunked causal mask function (by doing the intersection of both). This is
+ useful to easily overlay another mask on top of the chunked causal one, for example for image tokens handling.
+ """
+ # If we have an HybridCache structure, here we want to create the mask for the sliding layers
+ if hasattr(past_key_values, "is_sliding") and True in past_key_values.is_sliding:
+ layer_idx = past_key_values.is_sliding.index(True)
+ else:
+ layer_idx = 0
+
+ early_exit, attention_mask, packed_sequence_mask, kv_length, kv_offset = _preprocess_mask_arguments(
+ config, input_embeds, attention_mask, cache_position, past_key_values, position_ids, layer_idx
+ )
+ if early_exit:
+ return attention_mask
+
+ chunk_size = getattr(config, "attention_chunk_size", None)
+ if chunk_size is None:
+ raise ValueError("Could not find an `attention_chunk_size` argument in the config, or it is not set")
+
+ # Raise if using chunked attention on context too large with FA2
+ if config._attn_implementation == "flash_attention_2" and kv_length + kv_offset > chunk_size:
+ raise ValueError(
+ "Flash attention 2 cannot handle chunked attention, and the key-value length is larger than the chunk size so the "
+ "chunked pattern cannot be respected. You should use another `attn_implementation` when instantiating the model"
+ )
+
+ batch_size, dtype = input_embeds.shape[0], input_embeds.dtype
+ mask_factory_function = chunked_causal_mask_function(chunk_size)
+ mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[config._attn_implementation]
+
+ # Do not allow skip if we are compiling (this is to match BC)
+ # TODO: cyril -> probably revisit and remove this, but a lot of tests rely on it
+ allow_is_causal_skip = not past_key_values.is_compileable if past_key_values is not None else True
+
+ # If we detected packing format
+ if packed_sequence_mask is not None:
+ mask_factory_function = and_masks(mask_factory_function, packed_sequence_mask_function(packed_sequence_mask))
+ allow_is_causal_skip = False
+
+ # Allow slight deviations from chunked causal mask
+ if or_mask_function is not None:
+ mask_factory_function = or_masks(mask_factory_function, or_mask_function)
+ allow_is_causal_skip = False
+ if and_mask_function is not None:
+ mask_factory_function = and_masks(mask_factory_function, and_mask_function)
+ allow_is_causal_skip = False
+
+ # We now create the mask
+ causal_mask = mask_interface(
+ batch_size=batch_size,
+ cache_position=cache_position,
+ kv_length=kv_length,
+ kv_offset=kv_offset,
+ mask_function=mask_factory_function,
+ attention_mask=attention_mask,
+ allow_is_causal_skip=allow_is_causal_skip, # additional kwarg for sdpa
+ local_size=chunk_size, # Additional kwarg for sdpa
+ dtype=dtype, # Additional kwarg for eager
+ config=config, # Pass the config as well, in case someone wants to easily have their own mask_interface
+ )
+ return causal_mask
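+# Minimal usage sketch (names here are illustrative, not taken from a specific model file): a model's
+# construct() would typically build the mask once and hand it to every "chunked_attention" layer, e.g.
+#   chunked_mask = create_chunked_causal_mask(
+#       config=self.config, input_embeds=inputs_embeds, attention_mask=attention_mask,
+#       cache_position=cache_position, past_key_values=past_key_values, position_ids=position_ids,
+#   )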
+
+
LAYER_PATTERN_TO_MASK_FUNCTION_MAPPING = {
"full_attention": create_causal_mask,
"sliding_attention": create_sliding_window_causal_mask,
+ "chunked_attention": create_chunked_causal_mask,
}
+
+
+def create_masks_for_generate(
+ config: PretrainedConfig,
+ input_embeds: ms.Tensor,
+ attention_mask: Optional[ms.Tensor],
+ cache_position: ms.Tensor,
+ past_key_values: Optional[Cache],
+ position_ids: Optional[ms.Tensor] = None,
+ or_mask_function: Optional[Callable] = None,
+ and_mask_function: Optional[Callable] = None,
+ **kwargs,
+):
+ """
+ This function mimics how we create the masks in the `modeling_xxx.py` files, and is used in `generate` in order
+ to easily create the masks in advance, when we compile the forwards with Static caches.
+
+ Args:
+ config (`PretrainedConfig`):
+ The model config.
+ input_embeds (`ms.Tensor`):
+ The input embeddings of shape (batch_size, query_length, hidden_dim). This is used only to infer the
+ batch size, query length and dtype.
+ attention_mask (`ms.Tensor`, optional):
+ The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length).
+ It can also be an already prepared 4D mask, in which case it is returned as-is.
+ cache_position (`ms.Tensor`):
+ A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
+ past_key_values (`Cache`, optional):
+ The past key values, if we use a cache.
+ position_ids (`ms.Tensor`, optional)
+ A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences.
+ or_mask_function (`Callable`, optional):
+ An optional mask function to combine with the other mask function (by doing the union of both). This is
+ useful to easily overlay another mask on top of the causal one, for example for image tokens handling.
+ and_mask_function (`Callable`, optional):
+ An optional mask function to combine with the other mask function (by doing the intersection of both). This is
+ useful to easily overlay another mask on top of the causal one, for example for image tokens handling.
+ """
+ # The attribute reside in the text config for composite models
+ effective_config = config.get_text_config()
+ # Prepare the mask args
+ mask_kwargs = {
+ "config": effective_config,
+ "input_embeds": input_embeds,
+ "attention_mask": attention_mask,
+ "cache_position": cache_position,
+ "past_key_values": past_key_values,
+ "position_ids": position_ids,
+ "or_mask_function": or_mask_function,
+ "and_mask_function": and_mask_function,
+ }
+
+ # If the attribute exist, we need several masks
+ if hasattr(effective_config, "layer_types"):
+ causal_masks = {}
+ for layer_pattern in set(effective_config.layer_types):
+ causal_masks[layer_pattern] = LAYER_PATTERN_TO_MASK_FUNCTION_MAPPING[layer_pattern](**mask_kwargs)
+ return causal_masks
+ # In this case, all layers are sliding
+ elif getattr(effective_config, "sliding_window", None) is not None:
+ return create_sliding_window_causal_mask(**mask_kwargs)
+ # In this case, all layers are chunked
+ elif getattr(effective_config, "attention_chunk_size", None) is not None:
+ return create_chunked_causal_mask(**mask_kwargs)
+ # All layers use standard causal attention
+ return create_causal_mask(**mask_kwargs)
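+# Sketch of the possible return shapes: for a text config with layer_types = ["full_attention",
+# "sliding_attention"], this returns a dict like {"full_attention": ..., "sliding_attention": ...}
+# (each value being a 4D mask or None); otherwise a single mask (or None) is returned.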
+
+
+# Below are utilities to pretty-print the different masks
+# Print the matrix with words as row labels
+GREEN = "\033[92m"
+YELLOW = "\033[93m"
+RESET = "\033[0m"
+BLACK_SQUARE = "■"
+WHITE_SQUARE = "⬚"
+GREY_SQUARE = "∙"
+LOW_TRIANGLE = "⬕"
+UPPER_TRIANGLE = "⬔"
+
+
+def get_style(style):
+ if style == "majong":
+ BLACK_SQUARE = "🀞" # Full block (represents "on" or active)
+ BLACK_SQUARE = "🀙" # Full block (represents "on" or active)
+ WHITE_SQUARE = "🀆" # "▒" # Light shade (represents "off" or inactive)
+ LOW_TRIANGLE = "🀛" # Lower left triangle (stylized indication)
+ UPPER_TRIANGLE = "🀛" # Upper left triangle (stylized indication)
+ else:
+ BLACK_SQUARE = "█" # Full block (represents "on" or active)
+ WHITE_SQUARE = "░" # "▒" # Light shade (represents "off" or inactive)
+ LOW_TRIANGLE = "▙" # Lower left triangle (stylized indication))
+ UPPER_TRIANGLE = "▜" # Upper left triangle (stylized indication)
+
+ return BLACK_SQUARE, WHITE_SQUARE, LOW_TRIANGLE, UPPER_TRIANGLE
+
+
+# LOW_TRIANGLE = UPPER_TRIANGLE = "⟍" # Upper right triangle (stylized indication)
+
+YELLOW_SQUARE = f"{YELLOW}{BLACK_SQUARE}{RESET}"
+GREEN_SQUARE = f"{GREEN}{BLACK_SQUARE}{RESET}"
+
+
+def tensor_to_mask_visual(original_tensor: ms.Tensor, grid_size=(20, 40), style="majong") -> str:
+ BLACK_SQUARE, WHITE_SQUARE, LOW_TRIANGLE, UPPER_TRIANGLE = get_style(style)
+ h, w = original_tensor.shape
+ max_h, max_w = grid_size
+ if not (h < max_h and w < max_w):
+ # Preserve aspect ratio within max grid size
+ aspect_ratio = 2 * w / h
+ if aspect_ratio > 1:
+ w = max_w
+ h = min(max_h, max(1, round(max_w / aspect_ratio)))
+ else:
+ h = max_h
+ w = max(1, round(max_h * aspect_ratio))
+
+ # Step 1: Rescale tensor by average pooling
+ tensor = original_tensor.unsqueeze(0).unsqueeze(0) # Add batch and channel dimensions
+ tensor = F.adaptive_avg_pool2d(tensor, output_size=(h, w))[0, 0] # Remove extra dims
+ else:
+ tensor = original_tensor
+
+ # Step 3: Build the string representation
+ result = []
+ for i in range(h):
+ row = ""
+ for j in range(w):
+ if tensor[i, j] == 1:
+ row += BLACK_SQUARE
+ elif tensor[i, j] == 0:
+ row += WHITE_SQUARE
+ else:
+ if j > 0:
+ if tensor[i, j - 1] == 1:
+ row += LOW_TRIANGLE
+ elif tensor[i, j - 1] == 0:
+ row += UPPER_TRIANGLE
+ else:
+ row += BLACK_SQUARE if tensor[i, j] == 1 else WHITE_SQUARE
+ else:
+ row += (
+ BLACK_SQUARE
+ if tensor[i, j] == 1
+ else (
+ WHITE_SQUARE
+ if tensor[i, j] == 0
+ else (UPPER_TRIANGLE if tensor[i, j + 1] == 1 else LOW_TRIANGLE)
+ )
+ )
+ result.append(row)
+
+ return "\n".join(result)
+
+
+class AttentionMask(ms.Tensor):
+ def __new__(cls, data, style=None):
+ # Create a new instance of AttentionMask as a Tensor
+ cls.style = style
+ return ms.Tensor._make_subclass(cls, data, require_grad=False)
+
+ def __init__(self, data):
+ # You can initialize any additional metadata here if needed
+ pass
+
+ def to_string(self, grid_size=(20, 40), limit=4):
+ """Returns a string representation of the block mask."""
+ dense_mask = self
+ *batch_dims, num_rows, num_cols = dense_mask.shape
+ total_vis = []
+
+ for idx, batch_idx in enumerate(itertools.product(*[range(i) for i in batch_dims])):
+ if idx == limit:
+ total_vis.append("...")
+ total_vis.append("To print out more, set AttentionMask.to_string(limit=N)")
+ total_vis.append("You can also index (AttentionMask[batch, head]) to choose a specific batch or head")
+ break
+ block_vis = tensor_to_mask_visual(dense_mask[batch_idx], grid_size=grid_size, style=self.style)
+ total_vis.append(block_vis)
+
+ total_vis.append(f"ms.Tensor(shape={tuple(self.shape)}, dtype={self.dtype})")
+ return "\n".join(total_vis)
+
+ def __repr__(self):
+ return self.to_string()
+
+ def __str__(self):
+ return self.to_string()
+
+ @classmethod
+ def from_tensor(cls, tensor: ms.Tensor, style: Optional[str] = None) -> "AttentionMask":
+ res = cls(tensor)
+ res.style = style
+ return res
diff --git a/mindone/transformers/modeling_attn_mask_utils.py b/mindone/transformers/modeling_attn_mask_utils.py
index f2bf12136b..bb6a21cb88 100644
--- a/mindone/transformers/modeling_attn_mask_utils.py
+++ b/mindone/transformers/modeling_attn_mask_utils.py
@@ -14,6 +14,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+"""
+IMPORTANT NOTICE: Every class and function in this file is deprecated in favor of using the much more general
+`masking_utils.py` primitives. New code should not rely on it, it is only kept for backward compatibility for now,
+and will be removed in the future.
+"""
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
diff --git a/mindone/transformers/modeling_layers.py b/mindone/transformers/modeling_layers.py
new file mode 100644
index 0000000000..bff31b14ac
--- /dev/null
+++ b/mindone/transformers/modeling_layers.py
@@ -0,0 +1,263 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# This code is adapted from https://github.com/huggingface/transformers
+# with modifications to run transformers on mindspore.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from abc import ABC
+from typing import Optional
+
+from transformers.utils import auto_docstring, can_return_tuple
+
+import mindspore as ms
+import mindspore.nn as nn
+from mindspore import mint
+
+from .cache_utils import Cache
+from .modeling_outputs import (
+ BaseModelOutputWithPast,
+ QuestionAnsweringModelOutput,
+ SequenceClassifierOutputWithPast,
+ TokenClassifierOutput,
+)
+from .models.auto import AutoModel
+from .processing_utils import Unpack
+from .utils import TransformersKwargs, logging
+
+logger = logging.get_logger(__name__)
+
+
+class GradientCheckpointingLayer(nn.Cell):
+ """Base class for layers with gradient checkpointing.
+
+ This class enables gradient checkpointing functionality for a layer. By default, gradient checkpointing is disabled
+ (`gradient_checkpointing = False`). When `model.set_gradient_checkpointing()` is called, gradient checkpointing is
+ enabled by setting `gradient_checkpointing = True` and assigning a checkpointing function to `_gradient_checkpointing_func`.
+
+ Important:
+
+ When using gradient checkpointing with `use_reentrant=True`, inputs that require gradients (e.g. hidden states)
+ must be passed as positional arguments (`*args`) rather than keyword arguments to properly propagate gradients.
+
+ Example:
+
+ ```python
+ >>> # Correct - hidden_states passed as positional arg
+ >>> out = self.layer(hidden_states, attention_mask=attention_mask)
+
+ >>> # Incorrect - hidden_states passed as keyword arg
+ >>> out = self.layer(hidden_states=hidden_states, attention_mask=attention_mask)
+ ```
+ """
+
+ gradient_checkpointing = False
+
+ def __call__(self, *args, **kwargs):
+ if self.gradient_checkpointing and self.training:
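+            # Recompute-style gradient checkpointing is not wired up in this MindSpore port, so the
+            # enabled-while-training case fails loudly rather than silently skipping recomputation.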
+ raise NotImplementedError
+ return super().__call__(*args, **kwargs)
+
+
+@auto_docstring
+class GenericForSequenceClassification(ABC):
+ base_model_prefix = "model"
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ # Similar to `self.model = AutoModel.from_config(config)` but allows to change the base model name if needed in the child class
+ setattr(self, self.base_model_prefix, AutoModel.from_config(config))
+ self.score = mint.nn.Linear(config.hidden_size, self.num_labels, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @can_return_tuple
+ @auto_docstring
+ def construct(
+ self,
+ input_ids: Optional[ms.Tensor] = None,
+ attention_mask: Optional[ms.Tensor] = None,
+ position_ids: Optional[ms.Tensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[ms.Tensor] = None,
+ labels: Optional[ms.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> SequenceClassifierOutputWithPast:
+ transformer_outputs: BaseModelOutputWithPast = getattr(self, self.base_model_prefix)(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ **kwargs,
+ )
+ hidden_states = transformer_outputs.last_hidden_state
+ logits = self.score(hidden_states)
+
+ if input_ids is not None:
+ batch_size = input_ids.shape[0]
+ else:
+ batch_size = inputs_embeds.shape[0]
+
+ if self.config.pad_token_id is None and batch_size != 1:
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
+ if self.config.pad_token_id is None:
+ last_non_pad_token = -1
+ elif input_ids is not None:
+ # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
+ non_pad_mask = (input_ids != self.config.pad_token_id).to(ms.int32)
+ token_indices = mint.arange(input_ids.shape[-1], dtype=ms.int32)
+ last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
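+            # e.g. input_ids = [[11, 12, 13, pad, pad]] -> non_pad_mask = [1, 1, 1, 0, 0],
+            # token_indices * non_pad_mask = [0, 1, 2, 0, 0] -> last_non_pad_token = 2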
+ else:
+ last_non_pad_token = -1
+ logger.warning_once(
+ f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+ )
+
+ pooled_logits = logits[mint.arange(batch_size), last_non_pad_token]
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
+
+ return SequenceClassifierOutputWithPast(
+ loss=loss,
+ logits=pooled_logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
+
+@auto_docstring
+class GenericForQuestionAnswering(ABC):
+ base_model_prefix = "model"
+
+ def __init__(self, config):
+ super().__init__(config)
+ # Similar to `self.model = AutoModel.from_config(config)` but allows to change the base model name if needed in the child class
+ setattr(self, self.base_model_prefix, AutoModel.from_config(config))
+ self.qa_outputs = mint.nn.Linear(config.hidden_size, 2)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return getattr(self, self.base_model_prefix).embed_tokens
+
+ def set_input_embeddings(self, value):
+ getattr(self, self.base_model_prefix).embed_tokens = value
+
+ @can_return_tuple
+ @auto_docstring
+ def construct(
+ self,
+ input_ids: Optional[ms.Tensor] = None,
+ attention_mask: Optional[ms.Tensor] = None,
+ position_ids: Optional[ms.Tensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[ms.Tensor] = None,
+ start_positions: Optional[ms.Tensor] = None,
+ end_positions: Optional[ms.Tensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> QuestionAnsweringModelOutput:
+ outputs: BaseModelOutputWithPast = getattr(self, self.base_model_prefix)(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ **kwargs,
+ )
+
+ sequence_output = outputs.last_hidden_state
+
+ logits = self.qa_outputs(sequence_output)
+ start_logits, end_logits = logits.split(1, dim=-1)
+ start_logits = start_logits.squeeze(-1).contiguous()
+ end_logits = end_logits.squeeze(-1).contiguous()
+
+ loss = None
+ if start_positions is not None and end_positions is not None:
+ loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)
+
+ return QuestionAnsweringModelOutput(
+ loss=loss,
+ start_logits=start_logits,
+ end_logits=end_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@auto_docstring
+class GenericForTokenClassification(ABC):
+ base_model_prefix = "model"
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ # Similar to `self.model = AutoModel.from_config(config)` but allows to change the base model name if needed in the child class
+ setattr(self, self.base_model_prefix, AutoModel.from_config(config))
+ if getattr(config, "classifier_dropout", None) is not None:
+ classifier_dropout = config.classifier_dropout
+ elif getattr(config, "hidden_dropout", None) is not None:
+ classifier_dropout = config.hidden_dropout
+ else:
+ classifier_dropout = 0.1
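+        # dropout probability resolution: config.classifier_dropout, else config.hidden_dropout, else 0.1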
+        self.dropout = mint.nn.Dropout(p=classifier_dropout)
+        self.score = mint.nn.Linear(config.hidden_size, config.num_labels)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @can_return_tuple
+ @auto_docstring
+ def construct(
+ self,
+ input_ids: Optional[ms.Tensor] = None,
+ attention_mask: Optional[ms.Tensor] = None,
+ position_ids: Optional[ms.Tensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[ms.Tensor] = None,
+ labels: Optional[ms.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ **kwargs,
+ ) -> TokenClassifierOutput:
+ outputs: BaseModelOutputWithPast = getattr(self, self.base_model_prefix)(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ **kwargs,
+ )
+ sequence_output = outputs.last_hidden_state
+ sequence_output = self.dropout(sequence_output)
+ logits = self.score(sequence_output)
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits, labels, self.config)
+
+ return TokenClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
diff --git a/mindone/transformers/modeling_outputs.py b/mindone/transformers/modeling_outputs.py
index 252f1a04be..4a3f4ebd8e 100644
--- a/mindone/transformers/modeling_outputs.py
+++ b/mindone/transformers/modeling_outputs.py
@@ -22,6 +22,8 @@
import mindspore as ms
+from .cache_utils import Cache, EncoderDecoderCache
+
@dataclass
class BaseModelOutput(ModelOutput):
@@ -44,7 +46,7 @@ class BaseModelOutput(ModelOutput):
heads.
"""
- last_hidden_state: ms.Tensor = None
+ last_hidden_state: Optional[ms.Tensor] = None
hidden_states: Optional[Tuple[ms.Tensor, ...]] = None
attentions: Optional[Tuple[ms.Tensor, ...]] = None
@@ -64,7 +66,7 @@ class BaseModelOutputWithNoAttention(ModelOutput):
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
"""
- last_hidden_state: ms.Tensor = None
+ last_hidden_state: Optional[ms.Tensor] = None
hidden_states: Optional[Tuple[ms.Tensor, ...]] = None
@@ -94,8 +96,8 @@ class BaseModelOutputWithPooling(ModelOutput):
heads.
"""
- last_hidden_state: ms.Tensor = None
- pooler_output: ms.Tensor = None
+ last_hidden_state: Optional[ms.Tensor] = None
+ pooler_output: Optional[ms.Tensor] = None
hidden_states: Optional[Tuple[ms.Tensor, ...]] = None
attentions: Optional[Tuple[ms.Tensor, ...]] = None
@@ -117,8 +119,8 @@ class BaseModelOutputWithPoolingAndNoAttention(ModelOutput):
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
"""
- last_hidden_state: ms.Tensor = None
- pooler_output: ms.Tensor = None
+ last_hidden_state: Optional[ms.Tensor] = None
+ pooler_output: Optional[ms.Tensor] = None
hidden_states: Optional[Tuple[ms.Tensor, ...]] = None
@@ -133,11 +135,8 @@ class BaseModelOutputWithPast(ModelOutput):
If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
hidden_size)` is output.
- past_key_values (`tuple(tuple(ms.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
- Tuple of `tuple(ms.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
- `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
- `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
- encoder_sequence_length, embed_size_per_head)`.
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
`config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
@@ -155,8 +154,8 @@ class BaseModelOutputWithPast(ModelOutput):
heads.
"""
- last_hidden_state: ms.Tensor = None
- past_key_values: Optional[Tuple[Tuple[ms.Tensor]]] = None
+ last_hidden_state: Optional[ms.Tensor] = None
+ past_key_values: Optional[Cache] = None
hidden_states: Optional[Tuple[ms.Tensor, ...]] = None
attentions: Optional[Tuple[ms.Tensor, ...]] = None
@@ -189,7 +188,7 @@ class BaseModelOutputWithCrossAttentions(ModelOutput):
weighted average in the cross-attention heads.
"""
- last_hidden_state: ms.Tensor = None
+ last_hidden_state: Optional[ms.Tensor] = None
hidden_states: Optional[Tuple[ms.Tensor, ...]] = None
attentions: Optional[Tuple[ms.Tensor, ...]] = None
cross_attentions: Optional[Tuple[ms.Tensor, ...]] = None
@@ -226,21 +225,18 @@ class BaseModelOutputWithPoolingAndCrossAttentions(ModelOutput):
Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
weighted average in the cross-attention heads.
- past_key_values (`tuple(tuple(ms.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
- Tuple of `tuple(ms.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
- `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
- `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
- encoder_sequence_length, embed_size_per_head)`.
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
`config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
input) to speed up sequential decoding.
"""
- last_hidden_state: ms.Tensor = None
- pooler_output: ms.Tensor = None
+ last_hidden_state: Optional[ms.Tensor] = None
+ pooler_output: Optional[ms.Tensor] = None
hidden_states: Optional[Tuple[ms.Tensor, ...]] = None
- past_key_values: Optional[Tuple[Tuple[ms.Tensor]]] = None
+ past_key_values: Optional[Cache] = None
attentions: Optional[Tuple[ms.Tensor, ...]] = None
cross_attentions: Optional[Tuple[ms.Tensor, ...]] = None
@@ -256,11 +252,8 @@ class BaseModelOutputWithPastAndCrossAttentions(ModelOutput):
If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
hidden_size)` is output.
- past_key_values (`tuple(tuple(ms.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
- Tuple of `tuple(ms.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
- `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
- `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
- encoder_sequence_length, embed_size_per_head)`.
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
`config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
@@ -302,9 +295,8 @@ class MoECausalLMOutputWithPast(ModelOutput):
Language modeling loss (for next-token prediction).
logits (`ms.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
- past_key_values (`tuple(tuple(ms.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
- Tuple of `tuple(ms.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
- `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
`past_key_values` input) to speed up sequential decoding.
@@ -331,12 +323,12 @@ class MoECausalLMOutputWithPast(ModelOutput):
"""
loss: Optional[ms.Tensor] = None
- logits: ms.Tensor = None
- past_key_values: Optional[Tuple[Tuple[ms.Tensor]]] = None
+ logits: Optional[ms.Tensor] = None
+ past_key_values: Optional[Cache] = None
hidden_states: Optional[Tuple[ms.Tensor, ...]] = None
attentions: Optional[Tuple[ms.Tensor, ...]] = None
- z_loss: ms.Tensor = None
- aux_loss: ms.Tensor = None
+ z_loss: Optional[ms.Tensor] = None
+ aux_loss: Optional[ms.Tensor] = None
router_logits: Optional[Tuple[ms.Tensor]] = None
@@ -367,7 +359,7 @@ class MoEModelOutput(ModelOutput):
loss and the z_loss for Mixture of Experts models.
"""
- last_hidden_state: ms.Tensor = None
+ last_hidden_state: Optional[ms.Tensor] = None
hidden_states: Optional[Tuple[ms.Tensor, ...]] = None
attentions: Optional[Tuple[ms.Tensor, ...]] = None
router_probs: Optional[Tuple[ms.Tensor]] = None
@@ -381,11 +373,8 @@ class MoeModelOutputWithPast(ModelOutput):
Args:
last_hidden_state (`ms.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
- past_key_values (`tuple(tuple(ms.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
- Tuple of `tuple(ms.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
- `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
- `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
- encoder_sequence_length, embed_size_per_head)`.
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
`config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
@@ -409,8 +398,8 @@ class MoeModelOutputWithPast(ModelOutput):
loss for Mixture of Experts models.
"""
- last_hidden_state: ms.Tensor = None
- past_key_values: Optional[Tuple[Tuple[ms.Tensor]]] = None
+ last_hidden_state: Optional[ms.Tensor] = None
+ past_key_values: Optional[Cache] = None
hidden_states: Optional[Tuple[ms.Tensor, ...]] = None
attentions: Optional[Tuple[ms.Tensor, ...]] = None
router_logits: Optional[Tuple[ms.Tensor]] = None
@@ -438,9 +427,8 @@ class MoeCausalLMOutputWithPast(ModelOutput):
Raw router logits (post-softmax) that are computed by MoE routers; these terms are used to compute the auxiliary
loss for Mixture of Experts models.
- past_key_values (`tuple(tuple(ms.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
- Tuple of `tuple(ms.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
- `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
`past_key_values` input) to speed up sequential decoding.
@@ -459,8 +447,8 @@ class MoeCausalLMOutputWithPast(ModelOutput):
loss: Optional[ms.Tensor] = None
aux_loss: Optional[ms.Tensor] = None
- logits: ms.Tensor = None
- past_key_values: Optional[Tuple[Tuple[ms.Tensor]]] = None
+ logits: Optional[ms.Tensor] = None
+ past_key_values: Optional[Cache] = None
hidden_states: Optional[Tuple[ms.Tensor, ...]] = None
attentions: Optional[Tuple[ms.Tensor, ...]] = None
router_logits: Optional[Tuple[ms.Tensor]] = None
@@ -478,11 +466,8 @@ class MoEModelOutputWithPastAndCrossAttentions(ModelOutput):
If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
hidden_size)` is output.
- past_key_values (`tuple(tuple(ms.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
- Tuple of `tuple(ms.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
- `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
- `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
- encoder_sequence_length, embed_size_per_head)`.
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
`config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
@@ -572,8 +557,8 @@ class Seq2SeqModelOutput(ModelOutput):
self-attention heads.
"""
- last_hidden_state: ms.Tensor = None
- past_key_values: Optional[Tuple[Tuple[ms.Tensor]]] = None
+ last_hidden_state: Optional[ms.Tensor] = None
+ past_key_values: Optional[Cache] = None
decoder_hidden_states: Optional[Tuple[ms.Tensor, ...]] = None
decoder_attentions: Optional[Tuple[ms.Tensor, ...]] = None
cross_attentions: Optional[Tuple[ms.Tensor, ...]] = None
@@ -594,10 +579,9 @@ class Seq2SeqMoEModelOutput(ModelOutput):
If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
hidden_size)` is output.
- past_key_values (`tuple(tuple(ms.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
- Tuple of `tuple(ms.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
- `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
- `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+ past_key_values (`EncoderDecoderCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.EncoderDecoderCache`] instance. For more details, see our
+ [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
@@ -642,8 +626,8 @@ class Seq2SeqMoEModelOutput(ModelOutput):
modules.
"""
- last_hidden_state: ms.Tensor = None
- past_key_values: Optional[Tuple[Tuple[ms.Tensor]]] = None
+ last_hidden_state: Optional[ms.Tensor] = None
+ past_key_values: Optional[EncoderDecoderCache] = None
decoder_hidden_states: Optional[Tuple[ms.Tensor, ...]] = None
decoder_attentions: Optional[Tuple[ms.Tensor, ...]] = None
decoder_router_logits: Optional[Tuple[ms.Tensor]] = None
@@ -678,7 +662,7 @@ class CausalLMOutput(ModelOutput):
"""
loss: Optional[ms.Tensor] = None
- logits: ms.Tensor = None
+ logits: Optional[ms.Tensor] = None
hidden_states: Optional[Tuple[ms.Tensor, ...]] = None
attentions: Optional[Tuple[ms.Tensor, ...]] = None
@@ -693,9 +677,8 @@ class CausalLMOutputWithPast(ModelOutput):
Language modeling loss (for next-token prediction).
logits (`ms.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
- past_key_values (`tuple(tuple(ms.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
- Tuple of `tuple(ms.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
- `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
`past_key_values` input) to speed up sequential decoding.
@@ -713,8 +696,8 @@ class CausalLMOutputWithPast(ModelOutput):
"""
loss: Optional[ms.Tensor] = None
- logits: ms.Tensor = None
- past_key_values: Optional[Tuple[Tuple[ms.Tensor]]] = None
+ logits: Optional[ms.Tensor] = None
+ past_key_values: Optional[Cache] = None
hidden_states: Optional[Tuple[ms.Tensor, ...]] = None
attentions: Optional[Tuple[ms.Tensor, ...]] = None
@@ -746,18 +729,16 @@ class CausalLMOutputWithCrossAttentions(ModelOutput):
Cross attentions weights after the attention softmax, used to compute the weighted average in the
cross-attention heads.
- past_key_values (`tuple(tuple(ms.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
- Tuple of `ms.Tensor` tuples of length `config.n_layers`, with each tuple containing the cached key,
- value states of the self-attention and the cross-attention layers if model is used in encoder-decoder
- setting. Only relevant if `config.is_decoder = True`.
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
`past_key_values` input) to speed up sequential decoding.
"""
loss: Optional[ms.Tensor] = None
- logits: ms.Tensor = None
- past_key_values: Optional[Tuple[Tuple[ms.Tensor]]] = None
+ logits: Optional[ms.Tensor] = None
+ past_key_values: Optional[Cache] = None
hidden_states: Optional[Tuple[ms.Tensor, ...]] = None
attentions: Optional[Tuple[ms.Tensor, ...]] = None
cross_attentions: Optional[Tuple[ms.Tensor, ...]] = None
@@ -773,9 +754,8 @@ class SequenceClassifierOutputWithPast(ModelOutput):
Classification (or regression if config.num_labels==1) loss.
logits (`ms.Tensor` of shape `(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
- past_key_values (`tuple(tuple(ms.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
- Tuple of `tuple(ms.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
- `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
`past_key_values` input) to speed up sequential decoding.
@@ -793,8 +773,8 @@ class SequenceClassifierOutputWithPast(ModelOutput):
"""
loss: Optional[ms.Tensor] = None
- logits: ms.Tensor = None
- past_key_values: Optional[Tuple[Tuple[ms.Tensor]]] = None
+ logits: Optional[ms.Tensor] = None
+ past_key_values: Optional[Cache] = None
hidden_states: Optional[Tuple[ms.Tensor, ...]] = None
attentions: Optional[Tuple[ms.Tensor, ...]] = None
@@ -823,7 +803,7 @@ class MaskedLMOutput(ModelOutput):
"""
loss: Optional[ms.Tensor] = None
- logits: ms.Tensor = None
+ logits: Optional[ms.Tensor] = None
hidden_states: Optional[Tuple[ms.Tensor, ...]] = None
attentions: Optional[Tuple[ms.Tensor, ...]] = None
@@ -838,10 +818,9 @@ class Seq2SeqLMOutput(ModelOutput):
Language modeling loss.
logits (`ms.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
- past_key_values (`tuple(tuple(ms.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
- Tuple of `tuple(ms.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
- `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
- `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+ past_key_values (`EncoderDecoderCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.EncoderDecoderCache`] instance. For more details, see our
+ [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
@@ -878,8 +857,8 @@ class Seq2SeqLMOutput(ModelOutput):
"""
loss: Optional[ms.Tensor] = None
- logits: ms.Tensor = None
- past_key_values: Optional[Tuple[Tuple[ms.Tensor]]] = None
+ logits: Optional[ms.Tensor] = None
+ past_key_values: Optional[EncoderDecoderCache] = None
decoder_hidden_states: Optional[Tuple[ms.Tensor, ...]] = None
decoder_attentions: Optional[Tuple[ms.Tensor, ...]] = None
cross_attentions: Optional[Tuple[ms.Tensor, ...]] = None
@@ -898,10 +877,9 @@ class Seq2SeqMoEOutput(ModelOutput):
Language modeling loss.
logits (`ms.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
- past_key_values (`tuple(tuple(ms.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
- Tuple of `tuple(ms.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
- `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
- `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+ past_key_values (`EncoderDecoderCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.EncoderDecoderCache`] instance. For more details, see our
+ [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
@@ -947,12 +925,12 @@ class Seq2SeqMoEOutput(ModelOutput):
"""
loss: Optional[ms.Tensor] = None
- logits: ms.Tensor = None
- encoder_z_loss: ms.Tensor = None
- decoder_z_loss: ms.Tensor = None
- encoder_aux_loss: ms.Tensor = None
- decoder_aux_loss: ms.Tensor = None
- past_key_values: Optional[Tuple[Tuple[ms.Tensor]]] = None
+ logits: Optional[ms.Tensor] = None
+ encoder_z_loss: Optional[ms.Tensor] = None
+ decoder_z_loss: Optional[ms.Tensor] = None
+ encoder_aux_loss: Optional[ms.Tensor] = None
+ decoder_aux_loss: Optional[ms.Tensor] = None
+ past_key_values: Optional[EncoderDecoderCache] = None
decoder_hidden_states: Optional[Tuple[ms.Tensor, ...]] = None
decoder_attentions: Optional[Tuple[ms.Tensor, ...]] = None
decoder_router_logits: Optional[Tuple[ms.Tensor]] = None
@@ -988,7 +966,7 @@ class NextSentencePredictorOutput(ModelOutput):
"""
loss: Optional[ms.Tensor] = None
- logits: ms.Tensor = None
+ logits: Optional[ms.Tensor] = None
hidden_states: Optional[Tuple[ms.Tensor, ...]] = None
attentions: Optional[Tuple[ms.Tensor, ...]] = None
@@ -1017,7 +995,7 @@ class SequenceClassifierOutput(ModelOutput):
"""
loss: Optional[ms.Tensor] = None
- logits: ms.Tensor = None
+ logits: Optional[ms.Tensor] = None
hidden_states: Optional[Tuple[ms.Tensor, ...]] = None
attentions: Optional[Tuple[ms.Tensor, ...]] = None
@@ -1032,10 +1010,9 @@ class Seq2SeqSequenceClassifierOutput(ModelOutput):
Classification (or regression if config.num_labels==1) loss.
logits (`ms.Tensor` of shape `(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
- past_key_values (`tuple(tuple(ms.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
- Tuple of `tuple(ms.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
- `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
- `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+ past_key_values (`EncoderDecoderCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.EncoderDecoderCache`] instance. For more details, see our
+ [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
@@ -1072,8 +1049,8 @@ class Seq2SeqSequenceClassifierOutput(ModelOutput):
"""
loss: Optional[ms.Tensor] = None
- logits: ms.Tensor = None
- past_key_values: Optional[Tuple[Tuple[ms.Tensor]]] = None
+ logits: Optional[ms.Tensor] = None
+ past_key_values: Optional[EncoderDecoderCache] = None
decoder_hidden_states: Optional[Tuple[ms.Tensor, ...]] = None
decoder_attentions: Optional[Tuple[ms.Tensor, ...]] = None
cross_attentions: Optional[Tuple[ms.Tensor, ...]] = None
@@ -1108,7 +1085,7 @@ class MultipleChoiceModelOutput(ModelOutput):
"""
loss: Optional[ms.Tensor] = None
- logits: ms.Tensor = None
+ logits: Optional[ms.Tensor] = None
hidden_states: Optional[Tuple[ms.Tensor, ...]] = None
attentions: Optional[Tuple[ms.Tensor, ...]] = None
@@ -1137,7 +1114,7 @@ class TokenClassifierOutput(ModelOutput):
"""
loss: Optional[ms.Tensor] = None
- logits: ms.Tensor = None
+ logits: Optional[ms.Tensor] = None
hidden_states: Optional[Tuple[ms.Tensor, ...]] = None
attentions: Optional[Tuple[ms.Tensor, ...]] = None
@@ -1168,8 +1145,8 @@ class QuestionAnsweringModelOutput(ModelOutput):
"""
loss: Optional[ms.Tensor] = None
- start_logits: ms.Tensor = None
- end_logits: ms.Tensor = None
+ start_logits: Optional[ms.Tensor] = None
+ end_logits: Optional[ms.Tensor] = None
hidden_states: Optional[Tuple[ms.Tensor, ...]] = None
attentions: Optional[Tuple[ms.Tensor, ...]] = None
@@ -1177,58 +1154,59 @@ class QuestionAnsweringModelOutput(ModelOutput):
@dataclass
class Seq2SeqQuestionAnsweringModelOutput(ModelOutput):
"""
- Base class for outputs of sequence-to-sequence question answering models.
-
- Args:
- loss (`ms.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
- Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
- start_logits (`ms.Tensor` of shape `(batch_size, sequence_length)`):
- Span-start scores (before SoftMax).
- end_logits (`ms.Tensor` of shape `(batch_size, sequence_length)`):
- Span-end scores (before SoftMax).
- past_key_values (`tuple(tuple(ms.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
- Tuple of `tuple(ms.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
- `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
- `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
-
- Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
- blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
- decoder_hidden_states (`tuple(ms.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
- Tuple of `ms.Tensor` (one for the output of the embeddings, if the model has an embedding layer, +
- one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
-
- Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
- decoder_attentions (`tuple(ms.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
- Tuple of `ms.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
- sequence_length)`.
-
- Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
- self-attention heads.
- cross_attentions (`tuple(ms.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
- Tuple of `ms.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
- sequence_length)`.
-
- Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
- weighted average in the cross-attention heads.
- encoder_last_hidden_state (`ms.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
- Sequence of hidden-states at the output of the last layer of the encoder of the model.
- encoder_hidden_states (`tuple(ms.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
- Tuple of `ms.Tensor` (one for the output of the embeddings, if the model has an embedding layer, +
- one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
-
- Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
- encoder_attentions (`tuple(ms.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
- Tuple of `ms.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
- sequence_length)`.
-
- Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
- self-attention heads.
+ Base class for outputs of sequence-to-sequence question answering models.
+
+ Args:
+ loss (`ms.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
+ start_logits (`ms.Tensor` of shape `(batch_size, sequence_length)`):
+ Span-start scores (before SoftMax).
+ end_logits (`ms.Tensor` of shape `(batch_size, sequence_length)`):
+ Span-end scores (before SoftMax).
+ past_key_values (`EncoderDecoderCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.EncoderDecoderCache`] instance. For more details, see our
+ [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+        decoder_hidden_states (`tuple(ms.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `ms.Tensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
+ decoder_attentions (`tuple(ms.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `ms.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
+ self-attention heads.
+ cross_attentions (`tuple(ms.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `ms.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
+ weighted average in the cross-attention heads.
+ encoder_last_hidden_state (`ms.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder of the model.
+        encoder_hidden_states (`tuple(ms.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `ms.Tensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
+ encoder_attentions (`tuple(ms.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `ms.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
+ self-attention heads.
"""
loss: Optional[ms.Tensor] = None
- start_logits: ms.Tensor = None
- end_logits: ms.Tensor = None
- past_key_values: Optional[Tuple[Tuple[ms.Tensor]]] = None
+ start_logits: Optional[ms.Tensor] = None
+ end_logits: Optional[ms.Tensor] = None
+ past_key_values: Optional[EncoderDecoderCache] = None
decoder_hidden_states: Optional[Tuple[ms.Tensor, ...]] = None
decoder_attentions: Optional[Tuple[ms.Tensor, ...]] = None
cross_attentions: Optional[Tuple[ms.Tensor, ...]] = None
@@ -1270,7 +1248,7 @@ class SemanticSegmenterOutput(ModelOutput):
"""
loss: Optional[ms.Tensor] = None
- logits: ms.Tensor = None
+ logits: Optional[ms.Tensor] = None
hidden_states: Optional[Tuple[ms.Tensor, ...]] = None
attentions: Optional[Tuple[ms.Tensor, ...]] = None
@@ -1349,7 +1327,7 @@ class DepthEstimatorOutput(ModelOutput):
"""
loss: Optional[ms.Tensor] = None
- predicted_depth: ms.Tensor = None
+ predicted_depth: Optional[ms.Tensor] = None
hidden_states: Optional[Tuple[ms.Tensor, ...]] = None
attentions: Optional[Tuple[ms.Tensor, ...]] = None
@@ -1377,7 +1355,7 @@ class ImageSuperResolutionOutput(ModelOutput):
"""
loss: Optional[ms.Tensor] = None
- reconstruction: ms.Tensor = None
+ reconstruction: Optional[ms.Tensor] = None
hidden_states: Optional[Tuple[ms.Tensor, ...]] = None
attentions: Optional[Tuple[ms.Tensor, ...]] = None
@@ -1405,8 +1383,8 @@ class Wav2Vec2BaseModelOutput(ModelOutput):
heads.
"""
- last_hidden_state: ms.Tensor = None
- extract_features: ms.Tensor = None
+ last_hidden_state: Optional[ms.Tensor] = None
+ extract_features: Optional[ms.Tensor] = None
hidden_states: Optional[Tuple[ms.Tensor, ...]] = None
attentions: Optional[Tuple[ms.Tensor, ...]] = None
@@ -1437,8 +1415,8 @@ class XVectorOutput(ModelOutput):
"""
loss: Optional[ms.Tensor] = None
- logits: ms.Tensor = None
- embeddings: ms.Tensor = None
+ logits: Optional[ms.Tensor] = None
+ embeddings: Optional[ms.Tensor] = None
hidden_states: Optional[Tuple[ms.Tensor, ...]] = None
attentions: Optional[Tuple[ms.Tensor, ...]] = None
@@ -1465,7 +1443,7 @@ class BackboneOutput(ModelOutput):
heads.
"""
- feature_maps: Tuple[ms.Tensor] = None
+ feature_maps: Optional[Tuple[ms.Tensor]] = None
hidden_states: Optional[Tuple[ms.Tensor, ...]] = None
attentions: Optional[Tuple[ms.Tensor, ...]] = None
@@ -1500,8 +1478,8 @@ class BaseModelOutputWithPoolingAndProjection(ModelOutput):
Text embeddings before the projection layer, used to mimic the last hidden state of the teacher encoder.
"""
- last_hidden_state: ms.Tensor = None
- pooler_output: ms.Tensor = None
+ last_hidden_state: Optional[ms.Tensor] = None
+ pooler_output: Optional[ms.Tensor] = None
hidden_states: Optional[Tuple[ms.Tensor, ...]] = None
attentions: Optional[Tuple[ms.Tensor, ...]] = None
projection_state: Optional[Tuple[ms.Tensor]] = None
@@ -1517,10 +1495,9 @@ class Seq2SeqSpectrogramOutput(ModelOutput):
Spectrogram generation loss.
spectrogram (`ms.Tensor` of shape `(batch_size, sequence_length, num_bins)`):
The predicted spectrogram.
- past_key_values (`tuple(tuple(ms.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
- Tuple of `tuple(ms.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
- `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
- `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+ past_key_values (`EncoderDecoderCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.EncoderDecoderCache`] instance. For more details, see our
+ [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
@@ -1557,8 +1534,8 @@ class Seq2SeqSpectrogramOutput(ModelOutput):
"""
loss: Optional[ms.Tensor] = None
- spectrogram: ms.Tensor = None
- past_key_values: Optional[Tuple[Tuple[ms.Tensor]]] = None
+ spectrogram: Optional[ms.Tensor] = None
+ past_key_values: Optional[EncoderDecoderCache] = None
decoder_hidden_states: Optional[Tuple[ms.Tensor, ...]] = None
decoder_attentions: Optional[Tuple[ms.Tensor, ...]] = None
cross_attentions: Optional[Tuple[ms.Tensor, ...]] = None
@@ -1579,10 +1556,9 @@ class Seq2SeqTSModelOutput(ModelOutput):
If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
hidden_size)` is output.
- past_key_values (`tuple(tuple(ms.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
- Tuple of `tuple(ms.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
- `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
- `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+ past_key_values (`EncoderDecoderCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.EncoderDecoderCache`] instance. For more details, see our
+ [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
@@ -1626,8 +1602,8 @@ class Seq2SeqTSModelOutput(ModelOutput):
Static features of each time series' in a batch which are copied to the covariates at inference time.
"""
- last_hidden_state: ms.Tensor = None
- past_key_values: Optional[Tuple[Tuple[ms.Tensor]]] = None
+ last_hidden_state: Optional[ms.Tensor] = None
+ past_key_values: Optional[EncoderDecoderCache] = None
decoder_hidden_states: Optional[Tuple[ms.Tensor, ...]] = None
decoder_attentions: Optional[Tuple[ms.Tensor, ...]] = None
cross_attentions: Optional[Tuple[ms.Tensor, ...]] = None
@@ -1650,10 +1626,9 @@ class Seq2SeqTSPredictionOutput(ModelOutput):
Distributional loss.
params (`ms.Tensor` of shape `(batch_size, num_samples, num_params)`):
Parameters of the chosen distribution.
- past_key_values (`tuple(tuple(ms.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
- Tuple of `tuple(ms.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
- `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
- `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+ past_key_values (`EncoderDecoderCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.EncoderDecoderCache`] instance. For more details, see our
+ [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
@@ -1699,7 +1674,7 @@ class Seq2SeqTSPredictionOutput(ModelOutput):
loss: Optional[ms.Tensor] = None
params: Optional[Tuple[ms.Tensor]] = None
- past_key_values: Optional[Tuple[Tuple[ms.Tensor]]] = None
+ past_key_values: Optional[EncoderDecoderCache] = None
decoder_hidden_states: Optional[Tuple[ms.Tensor, ...]] = None
decoder_attentions: Optional[Tuple[ms.Tensor, ...]] = None
cross_attentions: Optional[Tuple[ms.Tensor, ...]] = None
@@ -1722,7 +1697,7 @@ class SampleTSPredictionOutput(ModelOutput):
Sampled values from the chosen distribution.
"""
- sequences: ms.Tensor = None
+ sequences: Optional[ms.Tensor] = None
@dataclass
@@ -1748,7 +1723,7 @@ class MaskedImageModelingOutput(ModelOutput):
"""
loss: Optional[ms.Tensor] = None
- reconstruction: ms.Tensor = None
+ reconstruction: Optional[ms.Tensor] = None
hidden_states: Optional[Tuple[ms.Tensor, ...]] = None
attentions: Optional[Tuple[ms.Tensor, ...]] = None
diff --git a/mindone/transformers/modeling_rope_utils.py b/mindone/transformers/modeling_rope_utils.py
index 8eb1ae9875..fefeeca5b0 100644
--- a/mindone/transformers/modeling_rope_utils.py
+++ b/mindone/transformers/modeling_rope_utils.py
@@ -83,7 +83,7 @@ def wrapper(self, x, position_ids):
def _compute_default_rope_parameters(
- config: Optional[PretrainedConfig] = None, seq_len: Optional[int] = None, **rope_kwargs
+ config: Optional[PretrainedConfig] = None, seq_len: Optional[int] = None
) -> tuple[Tensor, float]:
"""
Computes the inverse frequencies according to the original RoPE implementation
@@ -92,25 +92,18 @@ def _compute_default_rope_parameters(
The model configuration.
seq_len (`int`, *optional*):
The current sequence length. Unused for this type of RoPE.
- rope_kwargs (`Dict`, *optional*):
- BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
Returns:
Tuple of (`Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
"""
- if config is not None and len(rope_kwargs) > 0:
- raise ValueError(
- "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
- f"`_compute_default_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}"
- )
- if len(rope_kwargs) > 0:
- base = rope_kwargs["base"]
- dim = rope_kwargs["dim"]
- elif config is not None:
- base = config.rope_theta
- partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
- head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
- dim = int(head_dim * partial_rotary_factor)
+ base = config.rope_theta
+ partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
+ head_dim = (
+ config.head_dim
+ if getattr(config, "head_dim", None) is not None
+ else config.hidden_size // config.num_attention_heads
+ )
+ dim = int(head_dim * partial_rotary_factor)
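+    # e.g. hidden_size=4096, num_attention_heads=32 -> head_dim=128; with partial_rotary_factor=0.5 -> dim=64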
attention_factor = 1.0 # Unused in this type of RoPE
@@ -122,7 +115,6 @@ def _compute_default_rope_parameters(
def _compute_linear_scaling_rope_parameters(
config: Optional[PretrainedConfig] = None,
seq_len: Optional[int] = None,
- **rope_kwargs,
) -> tuple["ms.Tensor", float]:
"""
Computes the inverse frequencies with linear scaling. Credits to the Reddit user /u/kaiokendev
@@ -131,24 +123,14 @@ def _compute_linear_scaling_rope_parameters(
The model configuration.
seq_len (`int`, *optional*):
The current sequence length. Unused for this type of RoPE.
- rope_kwargs (`Dict`, *optional*):
- BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
Returns:
Tuple of (`mindspore.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
"""
- if config is not None and len(rope_kwargs) > 0:
- raise ValueError(
- "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
- f"`_compute_linear_scaling_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}"
- )
- if len(rope_kwargs) > 0:
- factor = rope_kwargs["factor"]
- elif config is not None:
- factor = config.rope_scaling["factor"]
+ factor = config.rope_scaling["factor"]
# Gets the default RoPE parameters
- inv_freq, attention_factor = _compute_default_rope_parameters(config, seq_len, **rope_kwargs)
+ inv_freq, attention_factor = _compute_default_rope_parameters(config, seq_len)
# Then applies linear scaling to the frequencies.
# NOTE: originally, scaling was applied to the position_ids. However, we get `embs = inv_freq @ position_ids`, so
@@ -160,7 +142,6 @@ def _compute_linear_scaling_rope_parameters(
def _compute_dynamic_ntk_parameters(
config: Optional[PretrainedConfig] = None,
seq_len: Optional[int] = None,
- **rope_kwargs,
) -> tuple["ms.Tensor", float]:
"""
Computes the inverse frequencies with NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla
@@ -169,35 +150,34 @@ def _compute_dynamic_ntk_parameters(
The model configuration.
seq_len (`int`, *optional*):
The current sequence length, used to update the dynamic RoPE at inference time.
- rope_kwargs (`Dict`, *optional*):
- BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
Returns:
Tuple of (`mindspore.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
"""
# TODO (joao): use the new `original_max_position_embeddings` from rope_scaling
- if config is not None and len(rope_kwargs) > 0:
- raise ValueError(
- "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
- f"`_compute_dynamic_ntk_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}"
- )
- if len(rope_kwargs) > 0:
- base = rope_kwargs["base"]
- dim = rope_kwargs["dim"]
- max_position_embeddings = rope_kwargs["max_position_embeddings"]
- factor = rope_kwargs["factor"]
- elif config is not None:
- base = config.rope_theta
- partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
- head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
- dim = int(head_dim * partial_rotary_factor)
- max_position_embeddings = config.max_position_embeddings
- factor = config.rope_scaling["factor"]
+ base = config.rope_theta
+ partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
+ head_dim = (
+ config.head_dim
+ if getattr(config, "head_dim", None) is not None
+ else config.hidden_size // config.num_attention_heads
+ )
+ dim = int(head_dim * partial_rotary_factor)
+ max_position_embeddings = config.max_position_embeddings
+ factor = config.rope_scaling["factor"]
attention_factor = 1.0 # Unused in this type of RoPE
# seq_len: default to max_position_embeddings, e.g. at init time
- seq_len = seq_len if seq_len is not None and seq_len > max_position_embeddings else max_position_embeddings
+ if seq_len is None:
+ seq_len = max_position_embeddings
+ elif isinstance(seq_len, ms.Tensor):
+ seq_len = mint.maximum(
+ seq_len,
+ ms.tensor(max_position_embeddings, dtype=seq_len.dtype),
+ )
+ else:
+ seq_len = max(seq_len, max_position_embeddings)
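+    # i.e. seq_len is clamped to at least max_position_embeddings; mint.maximum keeps the clamp graph-friendly
+    # when seq_len arrives as a Tensor at runtime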
# Compute the inverse frequencies
base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2))
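+    # e.g. factor=2.0 and seq_len = 2 * max_position_embeddings with dim=128: base *= (4 - 1) ** (128 / 126) ≈ 3.05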
@@ -205,40 +185,52 @@ def _compute_dynamic_ntk_parameters(
return inv_freq, attention_factor
-def _compute_yarn_parameters(
- config: PretrainedConfig, seq_len: Optional[int] = None, **rope_kwargs
-) -> tuple["ms.Tensor", float]:
+def _compute_yarn_parameters(config: PretrainedConfig, seq_len: Optional[int] = None) -> tuple["ms.Tensor", float]:
"""
Computes the inverse frequencies with NTK scaling. Please refer to the
- [original paper](https://arxiv.org/abs/2309.00071)
+ [original paper](https://huggingface.co/papers/2309.00071)
Args:
config ([`~transformers.PretrainedConfig`]):
The model configuration.
seq_len (`int`, *optional*):
The current sequence length. Unused for this type of RoPE.
- rope_kwargs (`Dict`, *optional*):
- BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
Returns:
Tuple of (`mindspore.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
post-processing scaling factor applied to the computed cos/sin.
"""
- # No need to keep BC with yarn, unreleased when this new pattern was created.
- if len(rope_kwargs) > 0:
- raise ValueError(
- f"Unexpected arguments: `**rope_kwargs` should be unset in `_compute_yarn_parameters`, got {rope_kwargs}"
- )
-
base = config.rope_theta
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
- head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+ head_dim = (
+ config.head_dim
+ if getattr(config, "head_dim", None) is not None
+ else config.hidden_size // config.num_attention_heads
+ )
dim = int(head_dim * partial_rotary_factor)
- max_position_embeddings = config.max_position_embeddings
factor = config.rope_scaling["factor"]
+ attention_factor = config.rope_scaling.get("attention_factor")
+ mscale = config.rope_scaling.get("mscale")
+ mscale_all_dim = config.rope_scaling.get("mscale_all_dim")
+
+    # NOTE: DeepSeek-V3 (and potentially other models) modify `max_position_embeddings` and have an
+ # `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
+ # values to compute the default attention scaling factor, instead of using `factor`.
+ if "original_max_position_embeddings" in config.rope_scaling:
+ original_max_position_embeddings = config.rope_scaling["original_max_position_embeddings"]
+ factor = config.max_position_embeddings / original_max_position_embeddings
+ else:
+ original_max_position_embeddings = config.max_position_embeddings
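+    # e.g. a DeepSeek-V3-style config with max_position_embeddings=163840 and original_max_position_embeddings=4096
+    # yields factor = 163840 / 4096 = 40.0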
+
+ def get_mscale(scale, mscale=1):
+ if scale <= 1:
+ return 1.0
+ return 0.1 * mscale * math.log(scale) + 1.0
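+    # e.g. get_mscale(4.0) = 0.1 * 1 * ln(4) + 1 ≈ 1.139, while any scale <= 1 leaves the factor at 1.0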
# Sets the attention factor as suggested in the paper
- attention_factor = config.rope_scaling.get("attention_factor")
if attention_factor is None:
- attention_factor = 0.1 * math.log(factor) + 1.0
+ if mscale and mscale_all_dim:
+ attention_factor = float(get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dim))
+ else:
+ attention_factor = get_mscale(factor)
# Optional config options
# beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly)
@@ -270,7 +262,7 @@ def linear_ramp_factor(min, max, dim):
inv_freq_extrapolation = 1.0 / pos_freqs
inv_freq_interpolation = 1.0 / (factor * pos_freqs)
- low, high = find_correction_range(beta_fast, beta_slow, dim, base, max_position_embeddings)
+ low, high = find_correction_range(beta_fast, beta_slow, dim, base, original_max_position_embeddings)
# Get n-dimensional rotational scaling corrected for extrapolation
inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).float()
@@ -282,9 +274,7 @@ def linear_ramp_factor(min, max, dim):
return inv_freq, attention_factor
-def _compute_longrope_parameters(
- config: PretrainedConfig, seq_len: Optional[int] = None, **rope_kwargs
-) -> tuple["ms.Tensor", float]:
+def _compute_longrope_parameters(config: PretrainedConfig, seq_len: Optional[int] = None) -> tuple["ms.Tensor", float]:
"""
Computes the inverse frequencies with LongRoPE scaling. Please refer to the
[original implementation](https://github.com/microsoft/LongRoPE)
@@ -293,23 +283,19 @@ def _compute_longrope_parameters(
The model configuration.
seq_len (`int`, *optional*):
The current sequence length.
- rope_kwargs (`Dict`, *optional*):
- BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
Returns:
Tuple of (`mindspore.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
post-processing scaling factor applied to the computed cos/sin.
"""
# TODO (joao): use the new `original_max_position_embeddings` from rope_scaling
- # No need to keep BC with longrope, unreleased when this new pattern was created.
- if len(rope_kwargs) > 0:
- raise ValueError(
- "Unexpected arguments: `**rope_kwargs` should be unset in `_compute_longrope_parameters`, got "
- f"{rope_kwargs}"
- )
base = config.rope_theta
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
- head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+ head_dim = (
+ config.head_dim
+ if getattr(config, "head_dim", None) is not None
+ else config.hidden_size // config.num_attention_heads
+ )
dim = int(head_dim * partial_rotary_factor)
long_factor = config.rope_scaling["long_factor"]
short_factor = config.rope_scaling["short_factor"]
@@ -343,9 +329,7 @@ def _compute_longrope_parameters(
return inv_freq, attention_factor
-def _compute_llama3_parameters(
- config: PretrainedConfig, seq_len: Optional[int] = None, **rope_kwargs
-) -> tuple["ms.Tensor", float]:
+def _compute_llama3_parameters(config: PretrainedConfig, seq_len: Optional[int] = None) -> tuple["ms.Tensor", float]:
"""
Computes the inverse frequencies for llama 3.1.
@@ -354,14 +338,12 @@ def _compute_llama3_parameters(
The model configuration.
seq_len (`int`, *optional*):
The current sequence length. Unused for this type of RoPE.
- rope_kwargs (`Dict`, *optional*):
- BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
Returns:
Tuple of (`mindspore.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
post-processing scaling factor applied to the computed cos/sin.
"""
# Gets the default RoPE parameters
- inv_freq, attention_factor = _compute_default_rope_parameters(config, seq_len, **rope_kwargs)
+ inv_freq, attention_factor = _compute_default_rope_parameters(config, seq_len)
factor = config.rope_scaling["factor"] # `8` in the original implementation
low_freq_factor = config.rope_scaling["low_freq_factor"] # `1` in the original implementation
@@ -464,7 +446,14 @@ def _validate_yarn_parameters(config: PretrainedConfig, ignore_keys: Optional[se
rope_scaling = config.rope_scaling
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
required_keys = {"rope_type", "factor"}
- optional_keys = {"attention_factor", "beta_fast", "beta_slow"}
+ optional_keys = {
+ "attention_factor",
+ "beta_fast",
+ "beta_slow",
+ "original_max_position_embeddings",
+ "mscale",
+ "mscale_all_dim",
+ }
received_keys = set(rope_scaling.keys())
_check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)
@@ -501,7 +490,11 @@ def _validate_longrope_parameters(config: PretrainedConfig, ignore_keys: Optiona
_check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
- head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+ head_dim = (
+ config.head_dim
+ if getattr(config, "head_dim", None) is not None
+ else config.hidden_size // config.num_attention_heads
+ )
dim = int(head_dim * partial_rotary_factor)
short_factor = rope_scaling.get("short_factor")
diff --git a/mindone/transformers/modeling_utils.py b/mindone/transformers/modeling_utils.py
index ca18c93474..7fc96f444f 100644
--- a/mindone/transformers/modeling_utils.py
+++ b/mindone/transformers/modeling_utils.py
@@ -21,10 +21,11 @@
import json
import os
import re
+import sys
import warnings
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
-from typing import Any, Callable, Dict, List, MutableMapping, Optional, Tuple, Union
+from typing import Any, Callable, MutableMapping, Optional, Union
from transformers.configuration_utils import PretrainedConfig
from transformers.dynamic_module_utils import custom_object_save
@@ -67,7 +68,8 @@
from .integrations.flash_attention import flash_attention_forward
from .integrations.sdpa_attention import sdpa_attention_forward
from .loss.loss_utils import LOSS_MAPPING
-from .mindspore_adapter import TORCH_TO_MINDSPORE_DTYPE_MAP, dtype_to_str
+from .masking_utils import ALL_MASK_ATTENTION_FUNCTIONS
+from .mindspore_adapter import dtype_to_str
from .mindspore_utils import ( # noqa: F401
Conv1D,
apply_chunking_to_forward,
@@ -77,7 +79,8 @@
prune_linear_layer,
)
from .modeling_attn_mask_utils import dtype_to_min
-from .utils.import_utils import is_flash_attn_2_available, is_sdpa_available
+from .utils.generic import _CAN_RECORD_REGISTRY, OutputRecorder
+from .utils.import_utils import is_sdpa_available
if is_safetensors_available():
from safetensors import safe_open
@@ -85,6 +88,27 @@
# from mindone.safetensors.mindspore import load_file as safe_load_file
from mindone.safetensors.mindspore import save_file as safe_save_file
+# DO NOT MODIFY, KEPT FOR BC ONLY
+VLMS = [
+ "aria",
+ "ayavision",
+ "colpali",
+ "emu3",
+ "fuyu",
+ "gotocr2",
+ "gemma3",
+ "internvl",
+ "llava", # all llava prefixed models fall under this check
+ "mistral3",
+ "mllama",
+ "paligemma",
+ "shieldgemma2",
+ "qwen2vl",
+ "qwen2_5_vl",
+ "videollava",
+ "vipllava",
+]
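+# (The model types above are the ones whose `_checkpoint_conversion_mapping` is applied automatically
+# in `save_pretrained`/`from_pretrained` further below.)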
+
logger = logging.get_logger(__name__)
_init_weights = True
@@ -126,12 +150,10 @@ def _get_pt2ms_mapped_k(mappings, has_prefix_module, expects_prefix_module, load
else mappings.get(s, (s, lambda x: x))[0]
for s in loaded_keys
]
- loaded_keys = [".".join([prefix, s]) for s in loaded_keys]
elif not has_prefix_module and expects_prefix_module:
loaded_keys = [
mappings.get(".".join([prefix, s]), (".".join([prefix, s]), lambda x: x))[0] for s in loaded_keys
]
- loaded_keys = [s[len(prefix) + 1 :] if s.startswith(prefix) else s for s in loaded_keys]
else:
loaded_keys = [mappings.get(s, (s, lambda x: x))[0] for s in loaded_keys]
return loaded_keys
@@ -147,17 +169,12 @@ def _convert_state_dict(m, state_dict_pt, prefix=""):
for name, param in m.parameters_and_names():
name_ms = param.name
length = len(prefix) + 1
- # State dict name conversion is added for dealing with the condition that model key and state_dict key mismatch
if name_pt.startswith(prefix):
- # if state_dict has prefix, check if model has prefix
- # When the prefix and the end of the name (such as embedding_table and weight) are removed, the consistency is judged
- # if yes, slice prefix from state_dict key
+            # If name_ms and name_pt match and name_pt carries the prefix, strip the prefix from name_pt
if name_ms.rsplit(".", 1)[0] == name_pt.rsplit(".", 1)[0][length:] or name_ms == name_pt[length:]:
name_pt = name_pt[length:]
elif not name_pt.startswith(prefix):
- # if state_dict does not have prefix, check if model has prefix
- # When the prefix and the end of the name (such as embedding_table and weight) are removed, the consistency is judged
- # if no, add prefix to state_dict key
+            # If name_ms and name_pt match and name_ms carries the prefix, prepend the prefix to name_pt
if name_pt.rsplit(".", 1)[0] == name_ms.rsplit(".", 1)[0][length:] or name_pt == name_ms[length:]:
name_pt = ".".join([prefix, name_pt])
name_ms, data_mapping = pt2ms_mappings.get(name_pt, (name_pt, lambda x: x))
@@ -230,7 +247,7 @@ def dtype_byte_size(dtype):
def shard_checkpoint(
- state_dict: Dict[str, Tensor], max_shard_size: Union[int, str] = "10GB", weights_name: str = WEIGHTS_NAME
+ state_dict: dict[str, Tensor], max_shard_size: Union[int, str] = "10GB", weights_name: str = WEIGHTS_NAME
):
"""
Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
@@ -249,7 +266,7 @@ def shard_checkpoint(
Args:
- state_dict (`Dict[str, Tensor]`): The state dictionary of a model to save.
+ state_dict (`dict[str, Tensor]`): The state dictionary of a model to save.
max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
The maximum size of each sub-checkpoint. If expressed as a string, needs to be digits followed by a unit
(like `"5MB"`).
@@ -380,13 +397,69 @@ def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
return weights_name
+def _get_mindspore_dtype(
+ cls,
+ mindspore_dtype: Optional[Union[str, ms.Type, dict]],
+ checkpoint_files: Optional[list[str]],
+ config: PretrainedConfig,
+ sharded_metadata: Optional[dict],
+ state_dict: Optional[dict],
+ weights_only: bool,
+ is_sharded: bool,
+):
+ # set dtype to instantiate the model under:
+ # 1. If mindspore_dtype is not None, we use that dtype
+ # 2. If mindspore_dtype is "auto", we auto-detect dtype from the loaded state_dict, by checking its first
+ # weights entry that is of a floating type - we assume all floating dtype weights are of the same dtype
+ # we also may have config.torch_dtype available, but we won't rely on it till v5
+
+ if mindspore_dtype is not None:
+ config.mindspore_dtype = dtype_to_str(mindspore_dtype)
+ for sub_config_key in config.sub_configs.keys():
+ sub_config = getattr(config, sub_config_key)
+ sub_config.mindspore_dtype = mindspore_dtype
+ if isinstance(mindspore_dtype, str):
+ if mindspore_dtype == "auto":
+ if hasattr(config, "torch_dtype") and config.torch_dtype is not None:
+ mindspore_dtype = config.torch_dtype
+ logger.info(f"Will use dtype={mindspore_dtype} as defined in model's config object")
+ else:
+ if is_sharded and "dtype" in sharded_metadata:
+ mindspore_dtype = sharded_metadata["dtype"]
+ elif not is_sharded:
+ mindspore_dtype = get_state_dict_dtype(state_dict)
+ else:
+ one_state_dict = load_state_dict(checkpoint_files[0])
+ mindspore_dtype = get_state_dict_dtype(one_state_dict)
+ del one_state_dict # free CPU memory
+ logger.info(
+ f"Since the `torch_dtype` attribute can't be found in model's config object, "
+ f"will use dtype={mindspore_dtype} as derived from model's weights"
+ )
+ else:
+ raise ValueError(
+ f'`mindspore_dtype` can be either `ms.Type` or `"auto"`, but received {mindspore_dtype}'
+ )
+ # TODO: We cannot set default mindspore dtype!
+ else:
+ # set fp32 as the default dtype for BC
+ # TODO: We cannot get default mindspore dtype! Therefore, we set default dtype to ms.float32
+ default_dtype = dtype_to_str(ms.float32)
+ config.mindspore_dtype = default_dtype
+ for key in config.sub_configs.keys():
+ value = getattr(config, key)
+ value.mindspore_dtype = default_dtype
+
+ return config, mindspore_dtype
+
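+# Note: `_get_mindspore_dtype` is reached through `from_pretrained`; e.g. (illustrative)
+# `model = SomeModel.from_pretrained("some/repo", mindspore_dtype="auto")` resolves the dtype from
+# `config.torch_dtype` when available, otherwise from the first floating-point weight of the checkpoint.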
+
def _find_missing_and_unexpected_keys(
cls,
model: "PreTrainedModel",
- original_checkpoint_keys: List[str],
- checkpoint_keys: List[str],
+ original_checkpoint_keys: list[str],
+ checkpoint_keys: list[str],
loading_base_model_from_task_state_dict: bool,
-) -> Tuple[List[str], List[str]]:
+) -> tuple[list[str], list[str]]:
"""Find missing keys (keys that are part of the model parameters but were NOT found in the loaded state dict keys) and unexpected keys
(keys found in the loaded state dict keys, but that are NOT part of the model parameters)
"""
@@ -435,10 +508,6 @@ def _get_name(self):
return self.__class__.__name__
def to(self, dtype: Optional[ms.Type] = None):
- # FIXME: In ms 2.6.0 `tensor.set_dtype()` encountered a bug that it occurs wrong values.
- # Resume to use self.register_buffer() in network and set dtype for buffer tensors after ms2.7.0 launched.
- # Now we use `Parameter` and `Parameter.set_dtype()` instead.
-
for p in self.get_parameters():
p.set_dtype(dtype)
return self
@@ -508,7 +577,7 @@ def create_extended_attention_mask_for_decoder(input_shape, attention_mask):
return extended_attention_mask
def get_extended_attention_mask(
- self, attention_mask: Tensor, input_shape: Tuple[int], dtype: ms.float32 = None
+ self, attention_mask: Tensor, input_shape: tuple[int], dtype: ms.float32 = None
) -> Tensor:
"""
Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
@@ -516,7 +585,7 @@ def get_extended_attention_mask(
Arguments:
attention_mask (`Tensor`):
Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
- input_shape (`Tuple[int]`):
+ input_shape (`tuple[int]`):
The shape of the input to the model.
Returns:
@@ -625,7 +694,7 @@ def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool
return sum(total_numel)
- def estimate_tokens(self, input_dict: Dict[str, Union[ms.Tensor, Any]]) -> int:
+ def estimate_tokens(self, input_dict: dict[str, Union[ms.Tensor, Any]]) -> int:
"""
Helper function to estimate the total number of tokens from the model inputs.
@@ -646,7 +715,7 @@ def estimate_tokens(self, input_dict: Dict[str, Union[ms.Tensor, Any]]) -> int:
self.warnings_issued["estimate_tokens"] = True
return 0
- def floating_point_ops(self, input_dict: Dict[str, Union[ms.Tensor, Any]], exclude_embeddings: bool = True) -> int:
+ def floating_point_ops(self, input_dict: dict[str, Union[ms.Tensor, Any]], exclude_embeddings: bool = True) -> int:
"""
Get number of (optionally, non-embeddings) floating-point operations for the forward and backward passes of a
batch with this transformer model. Default approximation neglects the quadratic dependency on the number of
@@ -671,7 +740,99 @@ def floating_point_ops(self, input_dict: Dict[str, Union[ms.Tensor, Any]], exclu
return 6 * self.estimate_tokens(input_dict) * self.num_parameters(exclude_embeddings=exclude_embeddings)
-class PreTrainedModel(nn.Cell, ModuleUtilsMixin, GenerationMixin, PushToHubMixin, PeftAdapterMixin):
+class EmbeddingAccessMixin:
+ """
+ Base utilities to regroup getters and setters for embeddings.
+    Introduces the `_input_embed_layer` attribute, which indicates
+ where the input embeddings come from and where they
+ should be set.
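+
+    Example (illustrative; `model` can be any model exposing this mixin):
+
+    ```python
+    embeddings = model.get_input_embeddings()  # nn.Cell mapping token ids to hidden states
+    model.set_input_embeddings(embeddings)     # setting the same module back is a no-op
+    ```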
+ """
+
+ _input_embed_layer = "embed_tokens" # default layer that holds input embeddings.
+
+ def get_input_embeddings(self) -> nn.Cell:
+ """
+ Returns the model's input embeddings.
+
+ Returns:
+ `nn.Cell`: A mindspore module mapping vocabulary to hidden states.
+ """
+
+ # 1) Check if the model has an attribute named 'embed_tokens' (the standard input embedding layer
+ # for most NLP models), and if so, return it.
+
+ name = getattr(self, "_input_embed_layer", "embed_tokens")
+
+ if (default_embedding := getattr(self, name, None)) is not None:
+ return default_embedding
+ # 2) encoder/decoder and VLMs like `Gemma3nForConditionalGeneration`
+
+ if hasattr(self, "model") and hasattr(self.model, "embed_tokens"):
+ return self.model.embed_tokens
+
+ # 3) vanilla decoder‑only architectures
+ elif hasattr(self, "embed_tokens"):
+ return self.embed_tokens
+ else:
+ base_model = getattr(self, "base_model_prefix", None)
+ if base_model is not None:
+ base_model = getattr(self, base_model, None)
+ if base_model is not None and base_model is not self:
+ return base_model.get_input_embeddings()
+ raise NotImplementedError(
+ f"`get_input_embeddings` not auto‑handled for {self.__class__.__name__}; "
+ "please override in the subclass."
+ )
+
+ def set_input_embeddings(self, value: nn.Cell):
+ """Fallback setter that handles **~70 %** of models in the code‑base.
+
+ Order of attempts:
+ 1. `self.model.embed_tokens`
+ 2. `self.embed_tokens`
+ 3. delegate to the *base model* if one exists
+ 4. otherwise raise `NotImplementedError` so subclasses still can (and
+ should) override for exotic layouts.
+ """
+
+ # 1) encoder/decoder and VLMs like `Gemma3nForConditionalGeneration`
+ name = getattr(self, "_input_embed_layer", "embed_tokens")
+ if hasattr(self, "model") and hasattr(self.model, name):
+ setattr(self.model, name, value)
+ # 2) as well as vanilla decoder‑only architectures
+ elif hasattr(self, name):
+ setattr(self, name, value)
+ # 3) recurse once into the registered *base* model (e.g. for encoder/decoder)
+ elif getattr(self, self.base_model_prefix, self) is not self:
+ base_model = getattr(self, self.base_model_prefix, self)
+ base_model.set_input_embeddings(value)
+ else:
+ raise NotImplementedError(
+ f"`set_input_embeddings` not auto‑handled for {self.__class__.__name__}; please override in the subclass."
+ )
+
+ def get_output_embeddings(self):
+ if not hasattr(self, "lm_head"):
+ return None
+ try:
+ # Speech / vision backbones raise here, so we return None.
+ # Legit use of get_input_embs?
+ self.get_input_embeddings()
+ except NotImplementedError:
+ return None
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ """
+ Sets the model's output embedding, defaulting to setting new_embeddings to lm_head.
+ """
+        if getattr(self, "lm_head", None) is not None:
+ self.lm_head = new_embeddings
+
+
+class PreTrainedModel(
+ nn.Cell, EmbeddingAccessMixin, ModuleUtilsMixin, GenerationMixin, PushToHubMixin, PeftAdapterMixin
+):
r"""
Base class for all models.
@@ -704,10 +865,16 @@ class PreTrainedModel(nn.Cell, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
main_input_name = "input_ids"
model_tags = None
+ _checkpoint_conversion_mapping = {} # used for BC support in VLMs, not meant to be used by new models
+
_auto_class = None
_no_split_modules = None
_skip_keys_device_placement = None
+
_keep_in_fp32_modules = None
+    # `_keep_in_fp32_modules` prevents casting to anything other than float32 (bfloat16 is still allowed);
+    # to also prevent bfloat16 casting, use the `_keep_in_fp32_modules_strict` flag
+ _keep_in_fp32_modules_strict = None
# a list of `re` patterns of `state_dict` keys that should be removed from the list of missing
# keys we find (keys inside the model but not in the checkpoint) and avoid unnecessary warnings.
@@ -724,19 +891,23 @@ class PreTrainedModel(nn.Cell, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
is_parallelizable = False
supports_gradient_checkpointing = False
+ _is_stateful = False
- # Flash Attention 2 support
- _supports_flash_attn_2 = False
+ # Flash Attention support
+ _supports_flash_attn = False
# SDPA support
_supports_sdpa = False
+ _can_compile_fullgraph = False
+
# Has support for a `Cache` instance as `past_key_values`? Does it support a `StaticCache`?
_supports_cache_class = False
_supports_static_cache = False
- # Has support for dynamic model input?
+    # Controls padding and static cache behavior
_supports_dynamic_input = False
+ _supports_jit = False
# Has support for a `QuantoQuantizedCache` instance as `past_key_values`
_supports_quantized_cache = False
@@ -745,11 +916,54 @@ class PreTrainedModel(nn.Cell, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
# In practice, it means that they support attention interface functions, fully pass the kwargs
# through all modules up to the Attention layer, can slice logits with Tensor, and have a default TP plan
_supports_attention_backend = False
+ _can_record_outputs = None
+
+ @property
+ def can_record_outputs(self) -> dict[str, OutputRecorder]:
+ """
+ Maps output names (e.g., "attentions", "hidden_states")
+ to either:
+ - A module class (e.g., `LlamaDecoderLayer`), using default index conventions:
+ * index=0 for "hidden_states"
+ * index=1 for "attentions"
+ - Or an `OutputRecorder(...)` with `target_class`, optional `index`, and `layer_name`.
+
+ Examples:
+ These two are equivalent:
+
+ ```python
+ _can_record_outputs = {
+ "attentions": LlamaAttention,
+ "hidden_states": LlamaDecoderLayer
+ }
+
+ _can_record_outputs = {
+ "attentions": OutputRecorder(LlamaAttention, index=1),
+ "hidden_states": OutputRecorder(LlamaDecoderLayer, index=0)
+ }
+ ```
+
+        You can also record outputs from the same class by specifying a layer name; before
+        collecting outputs, we check that they come from that layer.
+
+        For example, if both cross-attention and self-attention come from `LlamaAttention`, but the
+        self-attention lives in `self_attn`, you can do this:
+
+ ```python
+ class LlamaModel(PreTrainedModel):
+ _can_record_outputs = {
+                "attentions": OutputRecorder(LlamaAttention, index=1, layer_name="self_attn"),
+ "cross_attentions": OutputRecorder(LlamaAttention, index=1, layer_name="cross_attn")
+ }
+
+ ```
+ """
+ return self._can_record_outputs or {}
@property
- def dummy_inputs(self) -> Dict[str, Tensor]:
+ def dummy_inputs(self) -> dict[str, Tensor]:
"""
- `Dict[str, Tensor]`: Dummy inputs to do a forward pass in the network.
+ `dict[str, Tensor]`: Dummy inputs to do a forward pass in the network.
"""
return {"input_ids": Tensor(DUMMY_INPUTS)}
@@ -760,6 +974,30 @@ def framework(self) -> str:
"""
return "ms"
+ def __init_subclass__(cls, **kwargs):
+ super().__init_subclass__(**kwargs)
+ # For BC we keep the original `config_class` definition in case
+ # there is a `config_class` attribute (e.g. remote code models),
+ # otherwise we derive it from the annotated `config` attribute.
+
+ # defined in this particular subclass
+ child_annotation = cls.__dict__.get("__annotations__", {}).get("config", None)
+ child_attribute = cls.__dict__.get("config_class", None)
+
+ # defined in the class (this subclass or any parent class)
+ full_annotation = cls.__dict__.get("config", None)
+ full_attribute = cls.config_class
+
+ # priority (child class_config -> child annotation -> global class_config -> global annotation)
+ if child_attribute is not None:
+ cls.config_class = child_attribute
+ elif child_annotation is not None:
+ cls.config_class = child_annotation
+ elif full_attribute is not None:
+ cls.config_class = full_attribute
+ elif full_annotation is not None:
+ cls.config_class = full_annotation
+
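+    # Illustrative resolution of the priority above (hypothetical names): annotating the config is enough,
+    #
+    #     class MyModel(PreTrainedModel):
+    #         config: MyConfig
+    #
+    # and `MyModel.config_class` then resolves to `MyConfig` even if `config_class` is never set explicitly.
+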
def __init__(self, config: PretrainedConfig, *inputs, **kwargs):
super().__init__()
if not isinstance(config, PretrainedConfig):
@@ -768,19 +1006,24 @@ def __init__(self, config: PretrainedConfig, *inputs, **kwargs):
"`PretrainedConfig`. To create a model from a pretrained model use "
f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`"
)
- if not getattr(config, "_attn_implementation_autoset", False):
- # config usually has a `mindspore_dtype` but we need the next line for the `no_super_init` tests
- # TODO mindspore does not have get_default_dtype api
- dtype = ms.float32
- if hasattr(config, "torch_dtype") and config.torch_dtype is not None:
- if isinstance(config.torch_dtype, str):
- dtype = getattr(ms, config.torch_dtype)
- else:
- dtype = TORCH_TO_MINDSPORE_DTYPE_MAP[str(config.torch_dtype)]
- config = self._autoset_attn_implementation(config, mindspore_dtype=dtype)
# Save config and origin of the pretrained weights if given in model
self.config = config
+ # Check the attention implementation is supported, or set it if not yet set (on the internal attr, to avoid
+ # setting it recursively)
+        # TODO: default to "eager" for now because SDPA support is still immature
+ if self.config._attn_implementation == "sdpa":
+ self.config._attn_implementation = "eager"
+ # warn user that sdpa is not supported
+ logger.warning(
+                "SDPA is not supported yet. Falling back to the eager attention implementation. This warning can be silenced by passing "
+ '`attn_implementation="eager"` when loading the model. '
+ "Example: `model = AutoModel.from_pretrained('openai/whisper-tiny', attn_implementation='eager')`"
+ )
+ self.config._attn_implementation_internal = self._check_and_adjust_attn_implementation(
+ self.config._attn_implementation, is_init_check=True
+ )
+
# for initialization of the loss
loss_type = self.__class__.__name__
if loss_type not in LOSS_MAPPING:
@@ -799,6 +1042,10 @@ def __init__(self, config: PretrainedConfig, *inputs, **kwargs):
# `InstructBlipForConditionalGeneration` can dynamically update it without modifying the class attribute
# when a different component (e.g. language_model) is used.
self._keep_in_fp32_modules = copy.copy(self.__class__._keep_in_fp32_modules)
+ self._keep_in_fp32_modules_strict = copy.copy(self.__class__._keep_in_fp32_modules_strict)
+
+ self._no_split_modules = self._no_split_modules or []
+ _CAN_RECORD_REGISTRY[str(self.__class__)] = self._can_record_outputs # added for executorch support only
def post_init(self):
"""
@@ -807,86 +1054,6 @@ def post_init(self):
"""
self.init_weights()
- @classmethod
- def _autoset_attn_implementation(
- cls,
- config,
- use_flash_attention_2: bool = False,
- mindspore_dtype=None,
- ):
- """
- Automatically checks and dispatches to a default attention implementation. In order of priority:
- 1. An implementation specified in `config._attn_implementation` (due for example to the argument attn_implementation="sdpa" in from_pretrained).
- 2. DEPRECATED: if use_flash_attention_2 is set to `True` and `flash_attn` is available, flash attention. (`LlamaFlashAttention` for example)
- 3. SDPA implementation, if available and supported by the model type. (`LlamaSdpaAttention` for example)
- 4. The default model's implementation otherwise (`LlamaAttention` for example) .
- """
- # Here we use config._attn_implementation_internal to check whether the attention implementation was explicitely set by the user.
- # The property `PretrainedConfig._attn_implementation` is never `None`, for backward compatibility (always fall back on "eager").
- # The `hasattr` here is used as some Transformers tests for some reason do not call PretrainedConfig __init__ (e.g. test_no_super_init_config_and_model)
- requested_attn_implementation = None
- if hasattr(config, "_attn_implementation_internal") and config._attn_implementation_internal is not None:
- if config._attn_implementation != "flash_attention_2" and use_flash_attention_2:
- raise ValueError(
- f'Both attn_implementation="{config._attn_implementation}" and `use_flash_attention_2=True` were '
- f"used when loading the model, which are not compatible."
- ' We recommend to just use `attn_implementation="flash_attention_2"` when loading the model.'
- )
-
- if config._attn_implementation not in ["eager", "paged_attention"] + ALL_ATTENTION_FUNCTIONS.valid_keys():
- message = (
- f'Specified `attn_implementation="{config._attn_implementation}"` is not supported. '
- f'The only possible arguments are `attn_implementation="eager"`'
- f" (manual attention implementation)"
- )
- if cls._supports_flash_attn_2:
- message += ', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)'
- if cls._supports_sdpa:
- message += ', `"attn_implementation=sdpa"` (implementation using scaled_dot_product_attention)'
- raise ValueError(message + ".")
-
- # If a config is passed with a preset attn_implementation, we skip the automatic dispatch and use the
- # user-provided config, with hard checks that the requested attention implementation is available.
- requested_attn_implementation = config._attn_implementation_internal
-
- # Composite models consisting of several PretrainedModels have to specify attention impl as a dict
- # where keys are sub-config names. But most people will specify one `str` which means that should dispatch it
- # for all sub-models.
- # Below we check if a config is composite and manually prepare a dict of attn impl if not already passed as a dict.
- # Later each sub-module will dispatch with its own attn impl, by calling `XXXModel._from_config(config.text_config)`
- # If any of sub-modules doesn't support requested attn, an error will be raised. See https://github.com/huggingface/transformers/pull/32238
- for key in config.sub_configs.keys():
- sub_config = getattr(config, key)
- curr_attn_implementation = (
- requested_attn_implementation
- if not isinstance(requested_attn_implementation, dict)
- else requested_attn_implementation.get(key, None)
- )
- # For models with backbone sub-config might be not initialized
- if sub_config is not None:
- sub_config._attn_implementation_internal = curr_attn_implementation
-
- if use_flash_attention_2:
- logger.warning_once(
- "The model was loaded with use_flash_attention_2=True, which is deprecated and may be removed in a "
- 'future release. Please use `attn_implementation="flash_attention_2"` instead.'
- )
- config._attn_implementation = "flash_attention_2"
- if config._attn_implementation == "flash_attention_2":
- cls._check_and_enable_flash_attn_2(
- config,
- mindspore_dtype=mindspore_dtype,
- hard_check_only=False,
- )
- elif requested_attn_implementation in [None, "sdpa"]:
- # use_flash_attention_2 takes priority over SDPA, hence SDPA treated in this elif.
- config = cls._check_and_enable_sdpa(
- config,
- hard_check_only=False if requested_attn_implementation is None else True,
- )
-
- return config
-
@property
def base_model(self) -> nn.Cell:
"""
@@ -914,12 +1081,13 @@ def can_generate(cls) -> bool:
continue
if "PreTrainedModel" not in str(base) and base.can_generate():
return True
- # BC: Detects whether `prepare_inputs_for_generation` has been overwritten in the model. Prior to v4.45, this
+
+ # Detects whether `prepare_inputs_for_generation` has been overwritten in the model. Prior to v4.45, this
# was how we detected whether a model could generate.
- if "GenerationMixin" not in str(cls.prepare_inputs_for_generation):
- logger.warning_once(
+ if hasattr(cls, "prepare_inputs_for_generation"): # implicit: doesn't inherit `GenerationMixin`
+ logger.warning(
f"{cls.__name__} has generative capabilities, as `prepare_inputs_for_generation` is explicitly "
- "overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, "
+ "defined. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, "
"`PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability "
"to call `generate` and other related functions."
"\n - If you're using `trust_remote_code=True`, you can get rid of this warning by loading the "
@@ -929,50 +1097,9 @@ def can_generate(cls) -> bool:
"\n - If you are not the owner of the model architecture class, please contact the model code owner "
"to update it."
)
- return True
# Otherwise, can't generate
return False
- @classmethod
- def _check_and_enable_flash_attn_2(
- cls,
- config,
- mindspore_dtype=None,
- hard_check_only: bool = False,
- ) -> PretrainedConfig:
- """
- Checks the availability of Flash Attention 2 and compatibility with the current model.
-
- If all checks pass and `hard_check_only` is False, the method will set the config attribute
- `attn_implementation` to "flash_attention_2" so that the model can initialize the correct attention module.
- """
- if not cls._supports_flash_attn_2:
- raise ValueError(
- f"{cls.__name__} does not support Flash Attention 2.0 yet. Please request to add support where"
- f" the model is hosted, on its model hub page: https://huggingface.co/{config._name_or_path}/discussions/new"
- " or in the Transformers GitHub repo: https://github.com/huggingface/transformers/issues/new"
- )
-
- if not is_flash_attn_2_available():
- raise ImportError("FlashAttention2 has been toggled on, but it cannot be used due to some error")
-
- if mindspore_dtype is None:
- logger.warning_once(
- "You are attempting to use Flash Attention 2.0 without specifying a MindSpore dtype. This might lead to unexpected behaviour"
- )
- elif mindspore_dtype is not None and mindspore_dtype not in [ms.float16, ms.bfloat16]:
- logger.warning_once(
- "Flash Attention 2.0 only supports ms.float16 and ms.bfloat16 dtypes, but"
- f" the current dype in {cls.__name__} is {mindspore_dtype}. You should run training or inference using "
- f"Automatic Mixed-Precision via the `network=auto_mix_precision(network, ...)` decorator,"
- " or load the model with the `mindspore_dtype` argument. Example: `model = "
- 'AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", mindspore_dtype=ms.float16)`'
- )
-
- if not hard_check_only:
- config._attn_implementation = "flash_attention_2"
- return config
-
@property
def loss_function(self):
if hasattr(self, "_loss_function"):
@@ -996,33 +1123,6 @@ def loss_function(self, value):
def is_backend_compatible(cls):
return cls._supports_attention_backend
- @classmethod
- def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> PretrainedConfig:
- """
- Checks the availability of SDPA for a given model.
-
- If all checks pass and `hard_check_only` is False, the method will set the config attribute `_attn_implementation`
- to "flash_attention_2" so that the model can initialize the correct attention module.
- """
- if hard_check_only:
- if not cls._supports_sdpa:
- raise ValueError(
- f"{cls.__name__} does not support an attention implementation through `scaled_dot_product_attention` yet."
- " Please request the support for this architecture: https://github.com/huggingface/transformers/issues/28005. "
- "If you believe this error is a bug, please open an issue in Transformers GitHub repository and "
- 'load your model with the argument `attn_implementation="eager"` meanwhile. Example: '
- '`model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="eager")`'
- )
- if not is_sdpa_available():
- raise ImportError("SDPA requirements in Transformers are not met.")
-
- if not is_sdpa_available() or not cls._supports_sdpa:
- return config
-
- if not hard_check_only:
- config._attn_implementation = "sdpa"
- return config
-
@classmethod
def _from_config(cls, config, **kwargs):
"""
@@ -1043,27 +1143,19 @@ def _from_config(cls, config, **kwargs):
if isinstance(mindspore_dtype, str):
mindspore_dtype = getattr(ms, mindspore_dtype)
elif mindspore_dtype is not None and not isinstance(mindspore_dtype, ms.Type):
+ TORCH_TO_MINDSPORE_DTYPE_MAP = {
+ "torch.float32": ms.float32,
+ "torch.bfloat16": ms.bfloat16,
+ "torch.float16": ms.float16,
+ }
mindspore_dtype = str(mindspore_dtype)
mindspore_dtype = TORCH_TO_MINDSPORE_DTYPE_MAP[mindspore_dtype]
- use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False)
-
config = copy.deepcopy(config) # We do not want to modify the config inplace in _from_config.
- if config._attn_implementation_internal is not None:
- # In this case, the config has been created with the attn_implementation set by the user, which we
- # should respect.
- attn_implementation = config._attn_implementation_internal
- else:
- attn_implementation = None
-
- config._attn_implementation = kwargs.pop("attn_implementation", attn_implementation)
- if not getattr(config, "_attn_implementation_autoset", False):
- config = cls._autoset_attn_implementation(
- config,
- use_flash_attention_2=use_flash_attention_2,
- mindspore_dtype=mindspore_dtype,
- )
+ # If passing `attn_implementation` as kwargs, respect it (it will be applied recursively on subconfigs)
+ if "attn_implementation" in kwargs:
+ config._attn_implementation = kwargs.pop("attn_implementation")
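+        # e.g. (illustrative) `MyModel._from_config(config, attn_implementation="eager")` sets the
+        # implementation on `config`, and it is applied recursively to its sub-configs.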
model = cls(config, **kwargs)
@@ -1077,40 +1169,263 @@ def _from_config(cls, config, **kwargs):
return model
- def get_input_embeddings(self) -> nn.Cell:
+ def _flash_attn_2_can_dispatch(self, is_init_check: bool = False) -> bool:
"""
- Returns the model's input embeddings.
+ Check the availability of Flash Attention 2 for a given model.
- Returns:
- `nn.Cell`: A mindspore cell mapping vocabulary to hidden states.
+ Args:
+ is_init_check (`bool`, *optional*):
+ Whether this check is performed early, i.e. at __init__ time, or later when the model and its weights are
+ fully instantiated. This is needed as we also check the devices of the weights, and/or if the model uses
+                BetterTransformer, which are only available later after __init__. This allows us to raise proper exceptions early,
+                before instantiating the full model, if we know that the model does not support the requested attention.
"""
- base_model = getattr(self, self.base_model_prefix, self)
- if base_model is not self:
- return base_model.get_input_embeddings()
- else:
- raise NotImplementedError
+ mindspore_dtype = self.config.torch_dtype
+ if isinstance(mindspore_dtype, str):
+ mindspore_dtype = getattr(ms, mindspore_dtype)
+ elif mindspore_dtype is not None and not isinstance(mindspore_dtype, ms.Type):
+ TORCH_TO_MINDSPORE_DTYPE_MAP = {
+ "torch.float32": ms.float32,
+ "torch.bfloat16": ms.bfloat16,
+ "torch.float16": ms.float16,
+ }
+ mindspore_dtype = str(mindspore_dtype)
+ mindspore_dtype = TORCH_TO_MINDSPORE_DTYPE_MAP[mindspore_dtype]
- def set_input_embeddings(self, value: nn.Cell):
+ # check `supports_flash_attn_2` for BC with custom code. TODO: remove after a few releases
+ if not (self._supports_flash_attn or getattr(self, "_supports_flash_attn_2", False)):
+ raise ValueError(
+ f"{self.__class__.__name__} does not support Flash Attention 2.0 yet. Please request to add support where"
+ f" the model is hosted, on its model hub page: https://huggingface.co/{self.config._name_or_path}/discussions/new"
+ " or in the Transformers GitHub repo: https://github.com/huggingface/transformers/issues/new"
+ )
+
+        # FIXME: these variables are assigned but never used
+        # if not is_flash_attn_2_available():
+        #     preface = "FlashAttention2 has been toggled on, but it cannot be used due to the following error:"
+        #     install_message = "Please refer to the documentation of https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2 to install Flash Attention 2."
+
+ if mindspore_dtype is None:
+ logger.warning_once(
+                "You are attempting to use Flash Attention 2 without specifying a MindSpore dtype. This might lead to unexpected behaviour"
+ )
+ elif mindspore_dtype is not None and mindspore_dtype not in [ms.float16, ms.bfloat16]:
+ logger.warning_once(
+ "Flash Attention 2 only supports ms.float16 and ms.bfloat16 dtypes, but"
+                f" the current dtype in {self.__class__.__name__} is {mindspore_dtype}. You should run training or inference using Automatic Mixed-Precision,"
+                " or load the model with the `mindspore_dtype` argument. "
+ 'Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", mindspore_dtype=ms.float16)`'
+ )
+
+        # With the early check, the parameters are not yet initialized correctly
+ if not is_init_check:
+ if getattr(self, "use_bettertransformer", False):
+ raise ValueError(
+ "Flash Attention 2 and BetterTransformer API are not compatible. Please make sure to disable BetterTransformers "
+ "by doing model.reverse_bettertransformer()"
+ )
+
+ # If no error raise by this point, we can return `True`
+ return True
+
+ def _sdpa_can_dispatch(self, is_init_check: bool = False) -> bool:
"""
- Set model's input embeddings.
+ Check the availability of SDPA for a given model.
Args:
- value (`nn.Cell`): A cell mapping vocabulary to hidden states.
+ is_init_check (`bool`, *optional*):
+ Whether this check is performed early, i.e. at __init__ time, or later when the model and its weights are
+ fully instantiated. This is needed as we also check the devices of the weights, and/or if the model uses
+                BetterTransformer, which are only available later after __init__. This allows us to raise proper exceptions early,
+                before instantiating the full model, if we know that the model does not support the requested attention.
"""
- base_model = getattr(self, self.base_model_prefix, self)
- if base_model is not self:
- base_model.set_input_embeddings(value)
- else:
- raise NotImplementedError
+ if not self._supports_sdpa:
+ raise ValueError(
+ f"{self.__class__.__name__} does not support an attention implementation through torch.nn.functional.scaled_dot_product_attention yet."
+ " Please request the support for this architecture: https://github.com/huggingface/transformers/issues/28005. If you believe"
+ " this error is a bug, please open an issue in Transformers GitHub repository and "
+ 'load your model with the argument `attn_implementation="eager"` meanwhile. '
+ 'Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="eager")`'
+ )
+ if not is_sdpa_available():
+ raise ImportError(
+ "MindSpore SDPA requirements in Transformers are not met. Use `attn_implementation='eager'` instead."
+ )
+
+ return True
- def get_output_embeddings(self) -> nn.Cell:
+ def _check_and_adjust_attn_implementation(
+ self, attn_implementation: Optional[str], is_init_check: bool = False
+ ) -> str:
"""
- Returns the model's output embeddings.
+ Check that the `attn_implementation` exists and is supported by the models, and try to get the kernel from hub if
+ it matches hf kernels pattern.
+
+ Args:
+ attn_implementation (`str` or `None`):
+ The attention implementation to check for existence/validity.
+ is_init_check (`bool`, *optional*):
+ Whether this check is performed early, i.e. at __init__ time, or later when the model and its weights are
+ fully instantiated. This is needed as we also check the devices of the weights, and/or if the model uses
+                BetterTransformer, which are only available later after __init__. This allows us to raise proper exceptions early,
+                before instantiating the full model, if we know that the model does not support the requested attention.
Returns:
- `nn.Cell`: A mindspore cell mapping hidden states to vocabulary.
+ `str`: The final attention implementation to use, including potential fallbacks from sdpa to eager, or from
+ None to sdpa (to potentially eager).
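+
+        Example (illustrative): `"eager"`, `"sdpa"` and `"flash_attention_2"` are checked directly, while a string
+        such as `"some-org/some-kernels:flash_attn"` (hypothetical repo) is treated as a hub kernel reference.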
+ """
+ applicable_attn_implementation = "eager" if attn_implementation is None else attn_implementation
+ if re.match(r"^[^/:]+/[^/:]+:?[^/:]+$", applicable_attn_implementation):
+ # Extract repo_id and kernel_name from the string
+ if ":" in applicable_attn_implementation:
+ repo_id, kernel_name = attn_implementation.split(":")
+ kernel_name = kernel_name.strip()
+ else:
+ repo_id = attn_implementation
+ kernel_name = None
+ repo_id = repo_id.strip()
+ try:
+                # FIXME: loading kernels from the hub is not implemented in MindSpore; register the repo id
+                # to the flash_attention_2 mask function as a placeholder
+ ALL_MASK_ATTENTION_FUNCTIONS.register(repo_id, ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"])
+ applicable_attn_implementation = repo_id
+ except Exception as e:
+ logger.warning_once(
+ f"Could not find a kernel repository '{repo_id}' compatible with your device in the hub: {e}. Using "
+ "default attention implementation instead (sdpa if available, eager otherwise)."
+ )
+ applicable_attn_implementation = "sdpa" # Try to fallback to sdpa in this case
+ if applicable_attn_implementation not in ["eager", "paged_attention"] + ALL_ATTENTION_FUNCTIONS.valid_keys():
+ message = (
+ f'Specified `attn_implementation="{attn_implementation}"` is not supported. The only possible arguments are '
+ '`attn_implementation="eager"` (manual attention implementation)'
+ )
+ # check `supports_flash_attn_2` for BC with custom code. TODO: remove after a few releases
+ if self._supports_flash_attn or getattr(self, "_supports_flash_attn_2", False):
+ message += (
+ ', `"attn_implementation=flash_attention_3"` (implementation using flash attention 3)'
+ ', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)'
+ )
+ if self._supports_sdpa:
+ message += ', `"attn_implementation=sdpa"` '
+ if self._supports_flex_attn:
+ message += ', `"attn_implementation=flex_attention"`'
+ raise ValueError(message + ".")
+
+ # Perform relevant checks
+ if applicable_attn_implementation == "flash_attention_2":
+ self._flash_attn_2_can_dispatch(is_init_check)
+ elif applicable_attn_implementation == "flash_attention_3":
+ raise NotImplementedError(
+                "mindone.transformers does not support flash_attention_3 yet. Please use eager attention instead!"
+ )
+ elif applicable_attn_implementation == "flex_attention":
+ raise NotImplementedError(
+ "mindone.transformers does not support flex attention yet. Please use eager attention instead!"
+ )
+ elif applicable_attn_implementation == "sdpa":
+ # Sdpa is the default, so we try it and fallback to eager otherwise when not possible
+ try:
+ self._sdpa_can_dispatch(is_init_check)
+ except (ValueError, ImportError) as e:
+ # In this case, sdpa was requested explicitly, but we can't use it, so let's raise
+ if attn_implementation == "sdpa":
+ raise e
+ applicable_attn_implementation = "eager"
+
+ return applicable_attn_implementation
+
+ @classmethod
+ def _can_set_attn_implementation(cls) -> bool:
+ """Detect whether the class supports setting its attention implementation dynamically. It is an ugly check based on
+ opening the file, but avoids maintaining yet another property flag.
+ """
+ class_file = sys.modules[cls.__module__].__file__
+ with open(class_file, "r") as f:
+ code = f.read()
+ # heuristic -> if we find those patterns, the model uses the correct interface
+ return "eager_attention_forward" in code and "ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]" in code
+
+ def set_attn_implementation(self, attn_implementation: Union[str, dict]):
"""
- return None # Overwrite for models with output embeddings
+ Set the requested `attn_implementation` for this model.
+
+ Args:
+ attn_implementation (`str` or `dict`):
+ The attention implementation to set for this model. It can be either a `str`, in which case it will be
+ dispatched to all submodels if relevant, or a `dict` where keys are the sub_configs name, in which case each
+ submodel will dispatch the corresponding value.
+ """
+ requested_implementation = (
+ attn_implementation
+ if not isinstance(attn_implementation, dict)
+ else attn_implementation.get("", self.config._attn_implementation)
+ )
+
+ # At this point, the model was already instantiated, so instead of crashing on bad value, let's simply
+ # warn the user that the requested value is not working
+ if requested_implementation != self.config._attn_implementation:
+ # In this case, raise
+ if not self._can_set_attn_implementation():
+ logger.warning(
+ f"{self.__class__.__name__} does not support setting its attention implementation dynamically, because it "
+ "does not follow the functional approach based on AttentionInterface "
+ "(see https://huggingface.co/docs/transformers/en/attention_interface)"
+ )
+ else:
+ try:
+ applicable_attn_implementation = self._check_and_adjust_attn_implementation(
+ requested_implementation, is_init_check=False
+ )
+ # Apply the change (on the internal attr, to avoid setting it recursively)
+ self.config._attn_implementation_internal = applicable_attn_implementation
+ except (ValueError, ImportError) as e:
+ logger.warning(
+ f"Impossible to set the requested `attn_implementation`. The following error was captured: {str(e)}"
+ )
+
+ subconfigs_changed = set()
+ # Apply it to all submodels as well
+ for submodule in self.modules():
+ # We found a submodel (which is not self) with a different config (otherwise, it may be the same "actual model",
+ # e.g. ForCausalLM has a Model inside, but no need to check it again)
+ if (
+ submodule is not self
+ and isinstance(submodule, PreTrainedModel)
+ and submodule.config.__class__ != self.config.__class__
+ ):
+ sub_implementation = attn_implementation
+ if isinstance(attn_implementation, dict):
+ for subconfig_key in self.config.sub_configs:
+ # We need to check for exact object match here, with `is`
+ if getattr(self.config, subconfig_key) is submodule.config:
+ sub_implementation = attn_implementation.get(
+ subconfig_key, submodule.config._attn_implementation
+ )
+ break
+ submodule.set_attn_implementation(sub_implementation)
+ subconfigs_changed.add(submodule.config.__class__)
+
+ # We need this as some old and badly designed models use subconfigs without declaring the corresponding modules as PreTrainedModel
+ for subconfig_key in self.config.sub_configs:
+ subconfig = getattr(self.config, subconfig_key)
+ requested_implementation = (
+ attn_implementation
+ if not isinstance(attn_implementation, dict)
+ else attn_implementation.get(subconfig_key, subconfig._attn_implementation)
+ )
+ # This means we did not perform any check above for this particular subconfig -> set it in the dark if it is registered
+ if (
+ subconfig.__class__ not in subconfigs_changed
+ and requested_implementation != subconfig._attn_implementation
+ and requested_implementation in ["eager"] + ALL_ATTENTION_FUNCTIONS.valid_keys()
+ ):
+ subconfig._attn_implementation_internal = requested_implementation
+ logger.warning(
+ f"We set the attention implementation for the sub-config `{subconfig_key}` to `{requested_implementation}` "
+ "without finding the associated sub-model. For this reason we could not check if the model supports it. "
+ "You may encounter undefined behavior."
+ )
def _init_weights(self, module):
"""
@@ -1161,8 +1476,8 @@ def tie_weights(self):
def _tie_encoder_decoder_weights(
encoder: nn.Cell, decoder: nn.Cell, base_model_prefix: str, base_encoder_name: str
):
- uninitialized_encoder_weights: List[str] = []
- tied_weights: List[str] = []
+ uninitialized_encoder_weights: list[str] = []
+ tied_weights: list[str] = []
if decoder.__class__ != encoder.__class__:
logger.info(
f"{decoder.__class__} and {encoder.__class__} are not equal. In this case make sure that all encoder"
@@ -1174,7 +1489,7 @@ def tie_encoder_to_decoder_recursively(
encoder_pointer: nn.Cell,
module_name: str,
base_encoder_name: str,
- uninitialized_encoder_weights: List[str],
+ uninitialized_encoder_weights: list[str],
depth=0,
total_decoder_name="",
total_encoder_name="",
@@ -1539,7 +1854,7 @@ def resize_position_embeddings(self, new_num_position_embeddings: int):
f"overwrite this method in the class {self.__class__} in `modeling_{self.__class__.__module__}.py`"
)
- def get_position_embeddings(self) -> Union[nn.Embedding, Tuple[nn.Embedding]]:
+ def get_position_embeddings(self) -> Union[nn.Embedding, tuple[nn.Embedding]]:
raise NotImplementedError(
f"`get_position_embeddings` is not implemented for {self.__class__}`. To implement it, you should "
f"overwrite this method in the class {self.__class__} in `modeling_{self.__class__.__module__}.py`"
@@ -1617,12 +1932,12 @@ def save_pretrained(
If specified, weights are saved in the format pytorch_model..bin.
token (`str` or `bool`, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
- the token generated when running `huggingface-cli login` (stored in `~/.huggingface`).
+ the token generated when running `hf auth login` (stored in `~/.huggingface`).
save_peft_format (`bool`, *optional*, defaults to `True`):
For backward compatibility with PEFT library, in case adapter weights are attached to the model, all
- keys of the state dict of adapters needs to be pre-pended with `base_model.model`. Advanced users can
+ keys of the state dict of adapters needs to be prepended with `base_model.model`. Advanced users can
disable this behaviours by setting `save_peft_format` to `False`.
- kwargs (`Dict[str, Any]`, *optional*):
+ kwargs (`dict[str, Any]`, *optional*):
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
"""
use_auth_token = kwargs.pop("use_auth_token", None)
@@ -1664,10 +1979,6 @@ def save_pretrained(
# we currently don't use this setting automatically, but may start to use with v5
dtype = get_parameter_dtype(model_to_save)
model_to_save.config.torch_dtype = repr(dtype).split(".")[1]
- model_to_save.config.mindspore_dtype = repr(dtype).split(".")[1]
- for sub in ("text_config", "vision_config"):
- if hasattr(model_to_save.config, sub):
- getattr(model_to_save.config, sub).mindspore_dtype = repr(dtype).split(".")[1]
# Attach architecture to the config
model_to_save.config.architectures = [model_to_save.__class__.__name__]
@@ -1711,7 +2022,7 @@ def save_pretrained(
if save_peft_format:
logger.info(
"To match the expected format of the PEFT library, all keys of the state dict of adapters will "
- "be pre-pended with `base_model.model`."
+ "be prepended with `base_model.model`."
)
peft_state_dict = {}
for key, value in state_dict.items():
@@ -1735,6 +2046,25 @@ def save_pretrained(
if state_dict is None:
state_dict = {k: v for k, v in model_to_save.parameters_and_names()}
+ if any(
+ allowed_name in class_name.__name__.lower()
+ for class_name in self.__class__.__mro__[:-1]
+ for allowed_name in VLMS
+ ):
+ reverse_key_mapping = {v: k for k, v in self._checkpoint_conversion_mapping.items()}
+
+ original_state_dict = {}
+ for key, value in state_dict.items():
+ for pattern, replacement in reverse_key_mapping.items():
+ replacement = replacement.lstrip("^") # strip off un-needed chars and patterns
+ replacement = re.sub(r"\(.*\)", "", replacement)
+ key, n_replace = re.subn(pattern, replacement, key)
+ # Early exit of the loop
+ if n_replace > 0:
+ break
+ original_state_dict[key] = value
+ state_dict = original_state_dict
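+            # Illustrative (hypothetical) mapping: with `_checkpoint_conversion_mapping = {r"^language_model.model": "model.language_model"}`,
+            # the reverse mapping above renames "model.language_model.layers.0..." back to "language_model.model.layers.0..." on save.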
+
# Handle the case where some state_dict keys shouldn't be saved
if self._keys_to_ignore_on_save is not None:
for ignore_key in self._keys_to_ignore_on_save:
@@ -1808,6 +2138,7 @@ def from_pretrained(
token: Optional[Union[str, bool]] = None,
revision: str = "main",
use_safetensors: bool = None,
+ weights_only: bool = True,
**kwargs,
):
r"""
@@ -1856,7 +2187,7 @@ def from_pretrained(
save directory.
- The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a
configuration JSON file named *config.json* is found in the directory.
- state_dict (`Dict[str, Tensor]`, *optional*):
+ state_dict (`dict[str, Tensor]`, *optional*):
A state dictionary to use instead of a state dictionary loaded from saved weights file.
This option can be used if you want to create a model from a pretrained configuration but load your own
@@ -1881,7 +2212,7 @@ def from_pretrained(
resume_download (`bool`, *optional*, defaults to `False`):
Whether or not to delete incompletely received files. Will attempt to resume the download if such a
file exists.
- proxies (`Dict[str, str]`, *optional*):
+ proxies (`dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
output_loading_info(`bool`, *optional*, defaults to `False`):
@@ -2008,10 +2339,16 @@ def from_pretrained(
subfolder = kwargs.pop("subfolder", "")
commit_hash = kwargs.pop("_commit_hash", None)
variant = kwargs.pop("variant", None)
- use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False)
adapter_kwargs = kwargs.pop("adapter_kwargs", {})
adapter_name = kwargs.pop("adapter_name", "default")
+ key_mapping = kwargs.pop("key_mapping", None)
+ # Load models with hardcoded key mapping on class for VLMs only, to keep BC and standardize model
+ if key_mapping is None and any(
+ allowed_name in class_name.__name__.lower() for class_name in cls.__mro__[:-1] for allowed_name in VLMS
+ ):
+ key_mapping = cls._checkpoint_conversion_mapping
+
if use_auth_token is not None:
warnings.warn(
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
@@ -2110,12 +2447,13 @@ def from_pretrained(
# we pop attn_implementation from the kwargs but this handles the case where users
# passes manually the config to `from_pretrained`.
config = copy.deepcopy(config)
-
- kwarg_attn_imp = kwargs.pop("attn_implementation", None)
- if kwarg_attn_imp is not None and config._attn_implementation != kwarg_attn_imp:
- config._attn_implementation = kwarg_attn_imp
model_kwargs = kwargs
+ # Because some composite configs call super().__init__ before instantiating the sub-configs, we need this call
+ # to correctly redispatch recursively if the kwarg is provided
+ if "attn_implementation" in kwargs:
+ config._attn_implementation = kwargs.pop("attn_implementation")
+
# This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
# index of the files.
is_sharded = False
@@ -2368,7 +2706,9 @@ def from_pretrained(
with safe_open(resolved_archive_file, framework="np") as f:
metadata = f.metadata()
- if metadata.get("format") in ("np", "pt"):
+ if metadata is None:
+ pass
+ elif metadata.get("format") in ("np", "pt"):
pass
elif metadata.get("format") == "tf":
from_tf = True
@@ -2389,47 +2729,17 @@ def from_pretrained(
# Time to load the checkpoint
state_dict = load_state_dict(resolved_archive_file)
- # set dtype to instantiate the model under:
- # 1. If mindspore_dtype is not None, we use that dtype
- # 2. If mindspore_dtype is "auto", we auto-detect dtype from the loaded state_dict, by checking its first
- # weights entry that is of a floating type - we assume all floating dtype weights are of the same dtype
- # we also may have config.torch_dtype available, but we won't rely on it till v5
-
- if mindspore_dtype is not None:
- config.mindspore_dtype = dtype_to_str(mindspore_dtype)
- for sub_config_key in config.sub_configs.keys():
- sub_config = getattr(config, sub_config_key)
- sub_config.mindspore_dtype = mindspore_dtype
- if isinstance(mindspore_dtype, str):
- if mindspore_dtype == "auto":
- if hasattr(config, "torch_dtype") and config.torch_dtype is not None:
- mindspore_dtype = config.torch_dtype
- logger.info(f"Will use dtype={mindspore_dtype} as defined in model's config object")
- else:
- if is_sharded and "dtype" in sharded_metadata:
- mindspore_dtype = sharded_metadata["dtype"]
- elif not is_sharded:
- mindspore_dtype = get_state_dict_dtype(state_dict)
- else:
- one_state_dict = load_state_dict(resolved_archive_file[0])
- mindspore_dtype = get_state_dict_dtype(one_state_dict)
- del one_state_dict # free CPU memory
- logger.info(
- f"Since the `torch_dtype` attribute can't be found in model's config object, "
- f"will use dtype={mindspore_dtype} as derived from model's weights"
- )
- else:
- raise ValueError(
- f'`mindspore_dtype` can be either `ms.Type` or `"auto"`, but received {mindspore_dtype}'
- )
- # TODO: We cannot set default mindspore dtype!
- else:
- # TODO: We cannot get default mindspore dtype!
- default_dtype = dtype_to_str(ms.float32)
- config.mindspore_dtype = default_dtype
- for key in config.sub_configs.keys():
- value = getattr(config, key)
- value.mindspore_dtype = default_dtype
+ # Find the correct dtype based on current state
+ config, mindspore_dtype = _get_mindspore_dtype(
+ cls,
+ mindspore_dtype,
+ resolved_archive_file,
+ config,
+ sharded_metadata,
+ state_dict,
+ weights_only,
+ is_sharded,
+ )
# Check if `_keep_in_fp32_modules` is not None
use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (mindspore_dtype == ms.float16)
@@ -2442,9 +2752,6 @@ def from_pretrained(
config.name_or_path = pretrained_model_name_or_path
config = copy.deepcopy(config) # We do not want to modify the config inplace in from_pretrained.
- config = cls._autoset_attn_implementation(
- config, use_flash_attention_2=use_flash_attention_2, mindspore_dtype=mindspore_dtype
- )
model = cls(config, *model_args, **model_kwargs)
@@ -2489,6 +2796,7 @@ def from_pretrained(
sharded_metadata=sharded_metadata,
dtype=mindspore_dtype,
keep_in_fp32_modules=keep_in_fp32_modules,
+ key_mapping=key_mapping,
)
if _adapter_model_path is not None:
@@ -2541,7 +2849,7 @@ def from_pretrained(
return model
@staticmethod
- def _fix_state_dict_key_on_load(key: str) -> Tuple[str, bool]:
+ def _fix_state_dict_key_on_load(key: str) -> tuple[str, bool]:
"""Replace legacy parameter names with their modern equivalents. E.g. beta -> bias, gamma -> weight."""
# Rename LayerNorm beta & gamma params for some early models ported from Tensorflow (e.g. Bert)
# This rename is logged.
@@ -2554,8 +2862,8 @@ def _fix_state_dict_key_on_load(key: str) -> Tuple[str, bool]:
def _get_key_renaming_mapping(
self,
- checkpoint_keys: List[str],
- key_mapping: Optional[Dict[str, str]] = None,
+ checkpoint_keys: list[str],
+ key_mapping: Optional[dict[str, str]] = None,
loading_base_model_from_task_state_dict: bool = False,
loading_task_model_from_base_state_dict: bool = False,
):
@@ -2629,7 +2937,7 @@ def _load_pretrained_model(
sharded_metadata=None,
dtype=None,
keep_in_fp32_modules=None,
- key_mapping: Optional[Dict[str, str]] = None,
+ key_mapping: Optional[dict[str, str]] = None,
weights_only: bool = True,
):
model_state_dict = {k: v for k, v in model.parameters_and_names()}
@@ -2748,21 +3056,13 @@ def _find_mismatched_keys(
if state_dict is not None:
# Whole checkpoint
- state_dict = _convert_state_dict(model, state_dict, prefix)
- # In the original PyTorch implementation, this transformation is done via `_register_load_state_dict_pre_hook(load_hook)`.
- # Since MindSpore does not support such hooks, we manually apply the same renaming logic here.
- # This ensures compatibility with checkpoints loaded using the original naming convention.
- if "Mamba" in model.__class__.__name__:
- state_dict_tmp = {}
- for k, v in state_dict.items():
- new_k = k.replace("embedding.", "embeddings.") if "embedding." in k else k
- state_dict_tmp[new_k] = v
- state_dict = state_dict_tmp
-
+            # rename checkpoint keys from the original (PyTorch) naming to the HF model naming
matching = [s for s in key_renaming_mapping.keys() if "LayerNorm.gamma" in s]
if matching:
# Fix the key names when model weight names contain LayerNorm.gamma/LayerNorm.beta
state_dict = {key_renaming_mapping[k]: v for k, v in state_dict.items() if k in key_renaming_mapping}
+            # convert parameter names from the HF naming to the MindSpore naming
+ state_dict = _convert_state_dict(model, state_dict, prefix)
mismatched_keys = _find_mismatched_keys(
state_dict,
@@ -2789,23 +3089,15 @@ def _find_mismatched_keys(
# loading checkpoint
for shard_file in resolved_archive_file:
state_dict = load_state_dict(shard_file)
- state_dict = _convert_state_dict(model, state_dict, prefix)
- # In the original PyTorch implementation, this transformation is done via `_register_load_state_dict_pre_hook(load_hook)`.
- # Since MindSpore does not support such hooks, we manually apply the same renaming logic here.
- # This ensures compatibility with checkpoints loaded using the original naming convention.
- if "Mamba" in model.__class__.__name__:
- state_dict_tmp = {}
- for k, v in state_dict.items():
- new_k = k.replace("embedding.", "embeddings.") if "embedding." in k else k
- state_dict_tmp[new_k] = v
- state_dict = state_dict_tmp
-
+                # first pass: rename checkpoint keys from legacy/PyTorch names to the HF names expected by the model
matching = [s for s in key_renaming_mapping.keys() if "LayerNorm.gamma" in s]
if matching:
# Fix the key names when model weight names contain LayerNorm.gamma/LayerNorm.beta
state_dict = {
key_renaming_mapping[k]: v for k, v in state_dict.items() if k in key_renaming_mapping
}
+                # second pass: convert HF parameter names to their MindSpore equivalents
+ state_dict = _convert_state_dict(model, state_dict, prefix)
# Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
# matching the weights in the model.
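As a rough illustration of the two renaming passes annotated above (legacy/PyTorch key fix-ups first, MindSpore conversion second), here is a minimal, self-contained sketch. The rule tables and the `rename_keys` helper are hypothetical stand-ins, not the actual `key_renaming_mapping` / `_convert_state_dict` machinery used in `_load_pretrained_model`:

```python
import re

# Hypothetical rule tables, in the spirit of `_fix_state_dict_key_on_load` and the
# HF -> MindSpore conversion; not the exact mappings used by this module.
LEGACY_TO_HF = {r"LayerNorm\.gamma$": "LayerNorm.weight", r"LayerNorm\.beta$": "LayerNorm.bias"}
HF_TO_MS = {r"\.running_mean$": ".moving_mean", r"\.running_var$": ".moving_variance"}


def rename_keys(state_dict: dict, rules: dict) -> dict:
    """Apply regex-based key renaming; keys that match no rule are kept unchanged."""
    renamed = {}
    for key, value in state_dict.items():
        for pattern, replacement in rules.items():
            key = re.sub(pattern, replacement, key)
        renamed[key] = value
    return renamed


ckpt = {"encoder.LayerNorm.gamma": 1.0, "encoder.LayerNorm.beta": 0.0}
ckpt = rename_keys(ckpt, LEGACY_TO_HF)  # pass 1: legacy names -> HF names
ckpt = rename_keys(ckpt, HF_TO_MS)      # pass 2: HF names -> MindSpore parameter names
print(sorted(ckpt))  # ['encoder.LayerNorm.bias', 'encoder.LayerNorm.weight']
```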
@@ -2875,10 +3167,23 @@ def _find_mismatched_keys(
return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
+ def get_compiled_call(self) -> Callable:
+ """Return a `mindspore.jit`'d version of `self.__call__`. This is useful to dynamically choose between
+ non-compiled/compiled `forward` during inference, especially to switch between prefill (where we don't
+ want to use compiled version to avoid recomputing the graph with new shapes) and iterative decoding
+ (where we want the speed-ups of compiled version with static shapes)."""
+        # MindSpore has no drop-in `torch.compile` equivalent, so only the eager `__call__` can be returned here.
+ if "llama4" in self.config.model_type: # TODO try to enable for FULL COMPILE HYBRID CACHE SUPPORT
+ return self.__call__
+ raise NotImplementedError(
+            "mindone.transformers does not support operators like `torch.compile` right now. "
+            "Please add the @mindspore.jit decorator to the model's construct function instead!"
+ )
+
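The `NotImplementedError` above points callers toward `mindspore.jit` on `construct` instead of a `torch.compile`-style wrapper. A minimal sketch of that decorator pattern, using a toy cell that is not part of mindone (names, shapes and sizes are arbitrary):

```python
import mindspore as ms
from mindspore import mint, nn


class TinyDecodeStep(nn.Cell):
    """Toy stand-in for a static-shape decode step; illustrative only."""

    def __init__(self, hidden_size: int = 64):
        super().__init__()
        self.proj = mint.nn.Linear(hidden_size, hidden_size)

    @ms.jit  # compile the fixed-shape decode step; variable-shape prefill can stay eager
    def construct(self, x: ms.Tensor) -> ms.Tensor:
        return self.proj(x)


step = TinyDecodeStep()
out = step(mint.ones((1, 1, 64), dtype=ms.float32))
print(out.shape)  # (1, 1, 64)
```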
def retrieve_modules_from_names(self, names, add_prefix=False, remove_prefix=False):
module_keys = {".".join(key.split(".")[:-1]) for key in names}
        # torch.nn.ParameterList is a special case where two parameter keywords
# are appended to the module name, *e.g.* bert.special_embeddings.0
module_keys = module_keys.union(
{".".join(key.split(".")[:-2]) for key in names if len(key) > 0 and key[-1].isdigit()}
@@ -3043,9 +3348,9 @@ def construct(
), "One of start_states, start_positions should be not None"
if start_positions is not None:
slen, hsz = hidden_states.shape[-2:]
- start_positions = start_positions[:, None, None].expand((-1, -1, hsz)) # shape (bsz, 1, hsz)
+ start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
start_states = hidden_states.gather(-2, start_positions) # shape (bsz, 1, hsz)
- start_states = start_states.expand((-1, slen, -1)) # shape (bsz, slen, hsz)
+ start_states = start_states.expand(-1, slen, -1) # shape (bsz, slen, hsz)
x = self.dense_0(mint.cat([hidden_states, start_states], dim=-1))
x = self.activation(x)
@@ -3110,11 +3415,11 @@ def construct(
start_states is not None or start_positions is not None
), "One of start_states, start_positions should be not None"
if start_positions is not None:
- start_positions = start_positions[:, None, None].expand((-1, -1, hsz)) # shape (bsz, 1, hsz)
+ start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
start_states = hidden_states.gather(-2, start_positions).squeeze(-2) # shape (bsz, hsz)
if cls_index is not None:
- cls_index = cls_index[:, None, None].expand((-1, -1, hsz)) # shape (bsz, 1, hsz)
+ cls_index = cls_index[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
cls_token_state = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, hsz)
else:
cls_token_state = hidden_states[:, -1, :] # shape (bsz, hsz)
@@ -3190,7 +3495,7 @@ def construct(
is_impossible: Optional[ms.Tensor] = None,
p_mask: Optional[ms.Tensor] = None,
return_dict: bool = False,
- ) -> Union[SquadHeadOutput, Tuple[ms.Tensor]]:
+ ) -> Union[SquadHeadOutput, tuple[ms.Tensor]]:
"""
Args:
hidden_states (`mindspore.Tensor` of shape `(batch_size, seq_len, hidden_size)`):
@@ -3246,9 +3551,9 @@ def construct(
start_top_log_probs, start_top_index = mint.topk(
start_log_probs, self.start_n_top, dim=-1
) # shape (bsz, start_n_top)
- start_top_index_exp = start_top_index.unsqueeze(-1).expand((-1, -1, hsz)) # shape (bsz, start_n_top, hsz)
+ start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz) # shape (bsz, start_n_top, hsz)
start_states = mint.gather(hidden_states, -2, start_top_index_exp) # shape (bsz, start_n_top, hsz)
- start_states = start_states.unsqueeze(1).expand((-1, slen, -1, -1)) # shape (bsz, slen, start_n_top, hsz)
+ start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1) # shape (bsz, slen, start_n_top, hsz)
hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(
start_states
@@ -3416,7 +3721,7 @@ def __len__(self):
def register(cls, key: str, value: Callable):
cls._global_mapping.update({key: value})
- def valid_keys(self) -> List[str]:
+ def valid_keys(self) -> list[str]:
return list(self.keys())
diff --git a/mindone/transformers/models/__init__.py b/mindone/transformers/models/__init__.py
index 177b6e5cc0..28e7ced270 100644
--- a/mindone/transformers/models/__init__.py
+++ b/mindone/transformers/models/__init__.py
@@ -273,3 +273,6 @@
if version.parse(transformers.__version__) >= version.parse("4.53.0"):
from . import glm4v, minimax, qwen2_5_omni, vjepa2
+
+if version.parse(transformers.__version__) >= version.parse("4.57.0"):
+ from . import qwen3_vl, qwen3_vl_moe
diff --git a/mindone/transformers/models/aria/modeling_aria.py b/mindone/transformers/models/aria/modeling_aria.py
index b9c9e5cefc..e3716ef53f 100644
--- a/mindone/transformers/models/aria/modeling_aria.py
+++ b/mindone/transformers/models/aria/modeling_aria.py
@@ -26,7 +26,6 @@
from transformers import AriaConfig, AriaTextConfig
from transformers.utils import (
- LossKwargs,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
@@ -50,6 +49,7 @@
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, MSPreTrainedModel
from ...processing_utils import Unpack
+from ...utils import TransformersKwargs
# from ..auto import AutoModelForCausalLM, AutoModel
from ..idefics3 import Idefics3VisionTransformer
@@ -546,7 +546,7 @@ def construct(
attention_mask: Optional[ms.Tensor],
past_key_value: Optional[Cache] = None,
cache_position: Optional[ms.Tensor] = None,
- **kwargs: Unpack[FlashAttentionKwargs],
+ **kwargs: Unpack[TransformersKwargs],
) -> Tuple[ms.Tensor, Optional[ms.Tensor], Optional[Tuple[ms.Tensor]]]:
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
@@ -621,7 +621,7 @@ def construct(
use_cache: Optional[bool] = False,
cache_position: Optional[ms.Tensor] = None,
position_embeddings: Optional[Tuple[ms.Tensor, ms.Tensor]] = None, # necessary, but kept here for BC
- **kwargs: Unpack[FlashAttentionKwargs],
+ **kwargs: Unpack[TransformersKwargs],
) -> Tuple[ms.Tensor, Optional[Tuple[ms.Tensor, ms.Tensor]]]:
residual = hidden_states
@@ -909,7 +909,7 @@ def construct(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[ms.Tensor] = None,
- **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
+ **kwargs: Unpack[TransformersKwargs],
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -966,7 +966,7 @@ def construct(
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
- **flash_attn_kwargs,
+ **kwargs,
)
hidden_states = layer_outputs[0]
@@ -1078,10 +1078,6 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
return causal_mask
-class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs):
- ...
-
-
class AriaTextForCausalLM(AriaTextPreTrainedModel, GenerationMixin):
"""
Aria model for causal language modeling tasks.
@@ -1142,7 +1138,7 @@ def construct(
return_dict: Optional[bool] = None,
cache_position: Optional[ms.Tensor] = None,
logits_to_keep: Union[int, ms.Tensor] = 0,
- **kwargs: Unpack[KwargsForCausalLM],
+ **kwargs: Unpack[TransformersKwargs],
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
labels (`ms.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1183,7 +1179,7 @@ def construct(
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
- outputs = self.model(
+ outputs: BaseModelOutputWithPast = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
@@ -1197,7 +1193,7 @@ def construct(
**kwargs,
)
- hidden_states = outputs[0]
+ hidden_states = outputs.last_hidden_state
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
@@ -1206,10 +1202,6 @@ def construct(
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
- if not return_dict:
- output = (logits,) + outputs[1:]
- return (loss,) + output if loss is not None else output
-
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
@@ -1311,40 +1303,149 @@ class AriaCausalLMOutputWithPast(ModelOutput):
"""
-@add_start_docstrings(
- """Aria model for conditional generation tasks.
+@dataclass
+class AriaModelOutputWithPast(BaseModelOutputWithPast):
+ r"""
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(ms.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ image_hidden_states (`ms.Tensor`, *optional*):
+ A `ms.Tensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
+ image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
+ """
- This model combines a vision tower, a multi-modal projector, and a language model
- to perform tasks that involve both image and text inputs.""",
- ARIA_START_DOCSTRING,
-)
-class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
- config_class = AriaConfig
- _supports_flash_attn_2 = False
- _supports_flex_attn = False
- _supports_sdpa = False
- _tied_weights_keys = ["language_model.lm_head.weight"]
+ image_hidden_states: Optional[ms.Tensor] = None
+
+
+class AriaModel(AriaPreTrainedModel):
+ _checkpoint_conversion_mapping = {"language_model.model": "language_model"}
def __init__(self, config: AriaConfig):
super().__init__(config)
-
# self.vision_tower = AutoModel.from_config(config.vision_config)
self.vision_tower = Idefics3VisionTransformer(config.vision_config)
-
self.multi_modal_projector = AriaProjector(config)
- self.vocab_size = config.text_config.vocab_size
+
# self.language_model = AutoModelForCausalLM.from_config(config.text_config) # AriaTextForCausalLM
# OR
- self.language_model = AriaTextForCausalLM(config.text_config)
- self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
- self._use_flash_attention_2 = config.text_config._attn_implementation == "flash_attention_2"
+ self.language_model = AriaTextModel(config.text_config)
self.post_init()
+ def get_input_embeddings(self):
+ return self.language_model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.language_model.set_input_embeddings(value)
+
+ def set_decoder(self, decoder):
+ self.language_model = decoder
+
+ def get_decoder(self):
+ return self.language_model
+
+ def get_image_features(
+ self,
+ pixel_values: ms.Tensor,
+ pixel_mask: ms.Tensor = None,
+ vision_feature_layer: int = -1,
+ ):
+ patch_attention_mask = self._create_patch_attention_mask(pixel_mask)
+ image_outputs = self.vision_tower(
+ pixel_values, patch_attention_mask=patch_attention_mask, output_hidden_states=True
+ )
+ image_attn_mask = None
+ if patch_attention_mask is not None:
+ flattened_mask = patch_attention_mask.flatten(1)
+ image_attn_mask = mint.logical_not(flattened_mask)
+
+ selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
+ image_features = self.multi_modal_projector(selected_image_feature, attn_mask=image_attn_mask)
+ return image_features
+
+ def construct(
+ self,
+ input_ids: ms.Tensor = None,
+ pixel_values: ms.Tensor = None,
+ pixel_mask: ms.Tensor = None,
+ attention_mask: Optional[ms.Tensor] = None,
+ position_ids: Optional[ms.Tensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[ms.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[ms.Tensor] = None,
+ logits_to_keep: Union[int, ms.Tensor] = 0,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> Union[tuple, AriaModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+
+ # 2. Merge text and images
+ if pixel_values is not None and inputs_embeds.shape[1] != 1:
+ if input_ids is None:
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
+ ms.tensor(self.config.image_token_index, dtype=ms.int64)
+ )
+ n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0]
+ else:
+ image_embeds = input_ids == self.config.image_token_index
+ special_image_mask = image_embeds.unsqueeze(-1).expand_as(inputs_embeds)
+ n_image_tokens = (image_embeds).sum(dim=1).sum(dim=0)
+ image_features = self.get_image_features(
+ pixel_values=pixel_values,
+ pixel_mask=pixel_mask,
+ vision_feature_layer=self.config.vision_feature_layer,
+ )
+ n_images, n_features_per_image = image_features.shape[0], image_features.shape[1]
+ n_image_features = n_images * n_features_per_image
+ if n_image_tokens != n_image_features:
+ raise ValueError(
+ f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+ )
+
+ inputs_embeds = (
+ inputs_embeds.float().masked_scatter(special_image_mask, image_features.float()).to(inputs_embeds.dtype)
+ )
+
+ if logits_to_keep is None:
+ logits_to_keep = 0
+ outputs = self.language_model(
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ logits_to_keep=logits_to_keep,
+ cache_position=cache_position,
+ )
+
+ return AriaModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values if use_cache else None,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=image_features if pixel_values is not None else None,
+ )
+
def _create_patch_attention_mask(self, pixel_mask):
if pixel_mask is None:
return None
- # torch.tensor.unfold x 2: (B, H, W) => (B, H', W', K, K)
+ # ms.Tensor.unfold x 2: (B, H, W) => (B, H', W', K, K)
# patches_subgrid = pixel_mask.unfold(
# dimension=1,
# size=self.vision_tower.config.patch_size,
@@ -1370,42 +1471,69 @@ def _create_patch_attention_mask(self, pixel_mask):
return (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
+
+@add_start_docstrings(
+ """Aria model for conditional generation tasks.
+
+ This model combines a vision tower, a multi-modal projector, and a language model
+ to perform tasks that involve both image and text inputs.""",
+ ARIA_START_DOCSTRING,
+)
+class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
+ _checkpoint_conversion_mapping = {
+ "^language_model.model": "model.language_model",
+ "^vision_tower": "model.vision_tower",
+ "^multi_modal_projector": "model.multi_modal_projector",
+ "^language_model.lm_head": "lm_head",
+ }
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config: AriaConfig):
+ super().__init__(config)
+
+ self.model = AriaModel(config)
+ self.lm_head = mint.nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+ self.post_init()
+
def get_input_embeddings(self):
- return self.language_model.get_input_embeddings()
+ return self.model.get_input_embeddings()
def set_input_embeddings(self, value):
- self.language_model.set_input_embeddings(value)
-
- def get_output_embeddings(self):
- return self.language_model.get_output_embeddings()
+ self.model.set_input_embeddings(value)
- def set_output_embeddings(self, new_embeddings):
- self.language_model.set_output_embeddings(new_embeddings)
+ def get_output_embeddings(self) -> nn.Cell:
+ return self.lm_head
def set_decoder(self, decoder):
- self.language_model.set_decoder(decoder)
+ self.model.set_decoder(decoder)
def get_decoder(self):
- return self.language_model.get_decoder()
+        return self.model.get_decoder()
def get_image_features(
self,
pixel_values: ms.Tensor,
- pixel_mask: ms.Tensor = None,
+ pixel_mask: Optional[ms.Tensor] = None,
vision_feature_layer: int = -1,
):
- patch_attention_mask = self._create_patch_attention_mask(pixel_mask)
- image_outputs = self.vision_tower(
- pixel_values, patch_attention_mask=patch_attention_mask, output_hidden_states=True
+ return self.model.get_image_features(
+ pixel_values=pixel_values,
+ pixel_mask=pixel_mask,
+ vision_feature_layer=vision_feature_layer,
)
- image_attn_mask = None
- if patch_attention_mask is not None:
- flattened_mask = patch_attention_mask.flatten(1)
- image_attn_mask = mint.logical_not(flattened_mask)
- selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
- image_features = self.multi_modal_projector(selected_image_feature, attn_mask=image_attn_mask)
- return image_features
+    # Make modules available through the conditional class for BC
+ @property
+ def language_model(self):
+ return self.model.language_model
+
+ @property
+ def vision_tower(self):
+ return self.model.vision_tower
+
+ @property
+ def multi_modal_projector(self):
+ return self.model.multi_modal_projector
@add_start_docstrings_to_model_forward(ARIA_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=AriaCausalLMOutputWithPast, config_class=AriaConfig)
@@ -1425,7 +1553,7 @@ def construct(
return_dict: Optional[bool] = None,
logits_to_keep: Union[int, ms.Tensor] = 0,
cache_position: Optional[ms.Tensor] = None,
- **loss_kwargs,
+ **kwargs: Unpack[TransformersKwargs],
) -> Union[Tuple, AriaCausalLMOutputWithPast]:
r"""
labels (`ms.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1503,39 +1631,13 @@ def construct(
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- if inputs_embeds is None:
- inputs_embeds = self.get_input_embeddings()(input_ids)
-
- # 2. Merge text and images
- if pixel_values is not None and inputs_embeds.shape[1] != 1:
- if input_ids is None:
- special_image_mask = inputs_embeds == self.get_input_embeddings()(
- ms.tensor(self.config.image_token_index, dtype=ms.int64)
- )
- n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0]
- else:
- image_embeds = input_ids == self.config.image_token_index
- special_image_mask = image_embeds.unsqueeze(-1).expand_as(inputs_embeds)
- n_image_tokens = (image_embeds).sum(dim=1).sum(dim=0)
- image_features = self.get_image_features(
- pixel_values=pixel_values,
- pixel_mask=pixel_mask,
- vision_feature_layer=self.config.vision_feature_layer,
- )
- n_images, n_features_per_image = image_features.shape[0], image_features.shape[1]
- n_image_features = n_images * n_features_per_image
- if n_image_tokens != n_image_features:
- raise ValueError(
- f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
- )
-
- inputs_embeds = (
- inputs_embeds.float().masked_scatter(special_image_mask, image_features.float()).to(inputs_embeds.dtype)
- )
-
if logits_to_keep is None:
logits_to_keep = 0
- outputs = self.language_model(
+
+ outputs = self.model(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
+ pixel_mask=pixel_mask,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
@@ -1544,22 +1646,21 @@ def construct(
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
- logits_to_keep=logits_to_keep,
cache_position=cache_position,
+ **kwargs,
)
- logits = outputs[0]
+ hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
loss = self.loss_function(
- logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **loss_kwargs
+ logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
)
- if not return_dict:
- output = (logits,) + outputs[1:]
- return (loss,) + output if loss is not None else output
-
return AriaCausalLMOutputWithPast(
loss=loss,
logits=logits,
diff --git a/mindone/transformers/models/auto/configuration_auto.py b/mindone/transformers/models/auto/configuration_auto.py
index 42e03184dd..b1d25744a3 100644
--- a/mindone/transformers/models/auto/configuration_auto.py
+++ b/mindone/transformers/models/auto/configuration_auto.py
@@ -632,6 +632,24 @@
CONFIG_MAPPING_NAMES.update({"minimax": "MiniMaxConfig", "vjepa2": "VJEPA2Model"})
MODEL_NAMES_MAPPING.update({"minimax": "MiniMax", "vjepa2": "VJEPA2Model"})
+if version.parse(transformers.__version__) >= version.parse("4.57.0"):
+    CONFIG_MAPPING_NAMES.update(
+        {
+            "qwen3_vl": "Qwen3VLConfig",
+            "qwen3_vl_moe": "Qwen3VLMoeConfig",
+            "qwen3_vl_moe_text": "Qwen3VLMoeTextConfig",
+            "qwen3_vl_text": "Qwen3VLTextConfig",
+        }
+    )
+    MODEL_NAMES_MAPPING.update(
+        {
+            "qwen3_vl": "Qwen3VL",
+            "qwen3_vl_moe": "Qwen3VLMoe",
+            "qwen3_vl_moe_text": "Qwen3VLMoe",
+            "qwen3_vl_text": "Qwen3VL",
+        }
+    )
+
def model_type_to_module_name(key):
"""Converts a config key to the corresponding module."""
diff --git a/mindone/transformers/models/auto/image_processing_auto.py b/mindone/transformers/models/auto/image_processing_auto.py
index 517809e7d2..baea350d79 100644
--- a/mindone/transformers/models/auto/image_processing_auto.py
+++ b/mindone/transformers/models/auto/image_processing_auto.py
@@ -85,6 +85,9 @@
if version.parse(transformers.__version__) >= version.parse("4.53.0"):
IMAGE_PROCESSOR_MAPPING_NAMES.update({"glm4v": ("Glm4vImageProcessor",)})
+if version.parse(transformers.__version__) >= version.parse("4.57.0"):
+ IMAGE_PROCESSOR_MAPPING_NAMES.update({"qwen3_vl": ("Qwen2VLImageProcessor",)})
+
for model_type, image_processors in IMAGE_PROCESSOR_MAPPING_NAMES.items():
slow_image_processor_class, *fast_image_processor_class = image_processors
if not is_vision_available():
diff --git a/mindone/transformers/models/auto/modeling_auto.py b/mindone/transformers/models/auto/modeling_auto.py
index b71af89b26..c89e50bc47 100644
--- a/mindone/transformers/models/auto/modeling_auto.py
+++ b/mindone/transformers/models/auto/modeling_auto.py
@@ -1301,6 +1301,25 @@
MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES.update({"minimax": "MiniMaxForQuestionAnswering"})
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES.update({"minimax": "MiniMaxForTokenClassification"})
+if version.parse(transformers.__version__) >= version.parse("4.57.0"):
+    MODEL_MAPPING_NAMES.update(
+        {
+            "qwen3_vl": "Qwen3VLModel",
+            "qwen3_vl_moe": "Qwen3VLMoeModel",
+            "qwen3_vl_moe_text": "Qwen3VLMoeTextModel",
+            "qwen3_vl_text": "Qwen3VLTextModel",
+        }
+    )
+    MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES.update(
+        {
+            "qwen3_vl": "Qwen3VLForConditionalGeneration",
+            "qwen3_vl_moe": "Qwen3VLMoeForConditionalGeneration",
+        }
+    )
+    MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.update(
+        {"qwen3_vl": "Qwen3VLForConditionalGeneration", "qwen3_vl_moe": "Qwen3VLMoeForConditionalGeneration"}
+    )
+
MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES)
MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_PRETRAINING_MAPPING_NAMES)
MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_WITH_LM_HEAD_MAPPING_NAMES)
diff --git a/mindone/transformers/models/auto/processing_auto.py b/mindone/transformers/models/auto/processing_auto.py
index decc1e1448..88df3d8aee 100644
--- a/mindone/transformers/models/auto/processing_auto.py
+++ b/mindone/transformers/models/auto/processing_auto.py
@@ -76,6 +76,9 @@
if version.parse(transformers.__version__) >= version.parse("4.53.0"):
PROCESSOR_MAPPING_NAMES.update({"glm4v": "Glm4vProcessor"})
+if version.parse(transformers.__version__) >= version.parse("4.57.0"):
+ PROCESSOR_MAPPING_NAMES.update({"qwen3_vl": "Qwen3VLProcessor", "qwen3_vl_moe": "Qwen3VLProcessor"})
+
PROCESSOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, PROCESSOR_MAPPING_NAMES)
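Once the config, image-processor, model and processor mappings above are registered for `transformers >= 4.57`, the Auto classes can resolve the new model types by name. A hedged sketch of what that enables, assuming `AutoProcessor` and `AutoModelForImageTextToText` are exported by `mindone.transformers` the same way they are for the other vision-language models (weights are large, so this is illustrative rather than a test):

```python
import mindspore as ms
from mindone.transformers import AutoModelForImageTextToText, AutoProcessor  # assumed exports

model_id = "Qwen/Qwen3-VL-4B-Instruct"
processor = AutoProcessor.from_pretrained(model_id)  # resolves "qwen3_vl" -> Qwen3VLProcessor
model = AutoModelForImageTextToText.from_pretrained(model_id, mindspore_dtype=ms.bfloat16)
# resolves "qwen3_vl" -> Qwen3VLForConditionalGeneration via MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES
```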
diff --git a/mindone/transformers/models/aya_vision/modeling_aya_vision.py b/mindone/transformers/models/aya_vision/modeling_aya_vision.py
index 216188fd08..82c79fa8cd 100644
--- a/mindone/transformers/models/aya_vision/modeling_aya_vision.py
+++ b/mindone/transformers/models/aya_vision/modeling_aya_vision.py
@@ -31,10 +31,13 @@
from mindspore import mint, nn
from ...activations import ACT2FN
+from ...cache_utils import Cache
from ...generation import GenerationMixin
-from ...modeling_outputs import ModelOutput
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
from ...modeling_utils import PreTrainedModel
-from ..auto import AutoModel, AutoModelForCausalLM
+from ...processing_utils import Unpack
+from ..auto import AutoModel
_CONFIG_FOR_DOC = "AyaVisionConfig"
@@ -243,20 +246,32 @@ class AyaVisionCausalLMOutputWithPast(ModelOutput):
"""
-class AyaVisionForConditionalGeneration(AyaVisionPreTrainedModel, GenerationMixin):
+@dataclass
+class AyaVisionModelOutputWithPast(BaseModelOutputWithPast):
+ r"""
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+        Tuple of `tuple(ms.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+    image_hidden_states (`ms.Tensor`, *optional*):
+        A `ms.Tensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
+ image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
+ """
+
+ image_hidden_states: Optional[ms.Tensor] = None
+
+
+class AyaVisionModel(AyaVisionPreTrainedModel):
+ _checkpoint_conversion_mapping = {"language_model.model": "language_model"}
+
def __init__(self, config: AyaVisionConfig):
super().__init__(config)
self.vision_tower = AutoModel.from_config(config.vision_config)
self.multi_modal_projector = AyaVisionMultiModalProjector(config)
- self.vocab_size = config.text_config.vocab_size
- self.language_model = AutoModelForCausalLM.from_config(config.text_config)
-
- if self.language_model._tied_weights_keys is not None:
- self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]
-
- self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
-
+ self.language_model = AutoModel.from_config(config.text_config)
self.post_init()
def get_input_embeddings(self):
@@ -323,6 +338,154 @@ def get_image_features(
image_features = self.multi_modal_projector(selected_image_feature)
return image_features
+ def construct(
+ self,
+ input_ids: ms.Tensor = None,
+ pixel_values: ms.Tensor = None,
+ attention_mask: Optional[ms.Tensor] = None,
+ position_ids: Optional[ms.Tensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[ms.Tensor] = None,
+ vision_feature_layer: Optional[Union[int, list[int]]] = None,
+ vision_feature_select_strategy: Optional[str] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[ms.Tensor] = None,
+ logits_to_keep: Union[int, ms.Tensor] = 0,
+ image_sizes: Optional[ms.Tensor] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> Union[tuple, AyaVisionModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ vision_feature_layer = (
+ vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
+ )
+ vision_feature_select_strategy = (
+ vision_feature_select_strategy
+ if vision_feature_select_strategy is not None
+ else self.config.vision_feature_select_strategy
+ )
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if pixel_values is not None and inputs_embeds is not None:
+ raise ValueError(
+ "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
+ )
+
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+
+ if pixel_values is not None:
+ image_features = self.get_image_features(
+ pixel_values=pixel_values,
+ vision_feature_layer=vision_feature_layer,
+ vision_feature_select_strategy=vision_feature_select_strategy,
+ image_sizes=image_sizes,
+ )
+
+ special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
+ special_image_mask = special_image_mask.expand_as(inputs_embeds)
+ if inputs_embeds[special_image_mask].numel() != image_features.numel():
+ n_image_tokens = (input_ids == self.config.image_token_index).sum()
+ n_image_features = image_features.shape[0] * image_features.shape[1]
+ raise ValueError(
+ f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+ )
+ image_features = image_features.to(inputs_embeds.dtype)
+ if inputs_embeds.dtype == ms.bfloat16:
+ inputs_embeds = (
+ inputs_embeds.half().masked_scatter(special_image_mask, image_features.half()).to(ms.bfloat16)
+ )
+ else:
+ inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+
+ outputs = self.language_model(
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ logits_to_keep=logits_to_keep,
+ **kwargs,
+ )
+
+ return AyaVisionModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=image_features if pixel_values is not None else None,
+ )
+
+
+class AyaVisionForConditionalGeneration(AyaVisionPreTrainedModel, GenerationMixin):
+ _checkpoint_conversion_mapping = {
+ "^language_model.model": "model.language_model",
+ "^vision_tower": "model.vision_tower",
+ "^multi_modal_projector": "model.multi_modal_projector",
+ "^language_model.lm_head": "lm_head",
+ }
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config: AyaVisionConfig):
+ super().__init__(config)
+ self.model = AyaVisionModel(config)
+ self.lm_head = mint.nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.model.set_input_embeddings(value)
+
+ def get_output_embeddings(self) -> nn.Cell:
+ return self.lm_head
+
+ def set_decoder(self, decoder):
+ self.model.set_decoder(decoder)
+
+ def get_decoder(self):
+        return self.model.get_decoder()
+
+ def get_image_features(
+ self,
+ pixel_values: ms.Tensor,
+ vision_feature_layer: Optional[Union[int, list[int]]] = None,
+ vision_feature_select_strategy: Optional[str] = None,
+ **kwargs,
+ ):
+ return self.model.get_image_features(
+ pixel_values=pixel_values,
+ vision_feature_layer=vision_feature_layer,
+ vision_feature_select_strategy=vision_feature_select_strategy,
+ **kwargs,
+ )
+
+    # Make modules available through the conditional class for BC
+ @property
+ def language_model(self):
+ return self.model.language_model
+
+ @property
+ def vision_tower(self):
+ return self.model.vision_tower
+
+ @property
+ def multi_modal_projector(self):
+ return self.model.multi_modal_projector
+
def construct(
self,
input_ids: Optional[ms.Tensor] = None,
@@ -405,56 +568,28 @@ def construct(
else self.config.vision_feature_select_strategy
)
- if (input_ids is None) ^ (inputs_embeds is not None):
- raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
-
- if pixel_values is not None and inputs_embeds is not None:
- raise ValueError(
- "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
- )
-
- if inputs_embeds is None:
- inputs_embeds = self.get_input_embeddings()(input_ids)
-
- if pixel_values is not None:
- image_features = self.get_image_features(
- pixel_values=pixel_values,
- vision_feature_layer=vision_feature_layer,
- vision_feature_select_strategy=vision_feature_select_strategy,
- image_sizes=image_sizes,
- )
-
- special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
- special_image_mask = special_image_mask.expand_as(inputs_embeds)
- if inputs_embeds[special_image_mask].numel() != image_features.numel():
- n_image_tokens = (input_ids == self.config.image_token_index).sum()
- n_image_features = image_features.shape[0] * image_features.shape[1]
- raise ValueError(
- f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
- )
- image_features = image_features.to(inputs_embeds.dtype)
- if inputs_embeds.dtype == ms.bfloat16:
- inputs_embeds = (
- inputs_embeds.half().masked_scatter(special_image_mask, image_features.half()).to(ms.bfloat16)
- )
- else:
- inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
-
- outputs = self.language_model(
+ outputs = self.model(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
+ vision_feature_layer=vision_feature_layer,
+ vision_feature_select_strategy=vision_feature_select_strategy,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
- return_dict=return_dict,
+ return_dict=True,
cache_position=cache_position,
- logits_to_keep=logits_to_keep,
+ image_sizes=image_sizes,
**lm_kwargs,
)
- logits = outputs[0]
+ hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
@@ -482,7 +617,7 @@ def construct(
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
- image_hidden_states=image_features if pixel_values is not None else None,
+ image_hidden_states=outputs.image_hidden_states,
)
def prepare_inputs_for_generation(
diff --git a/mindone/transformers/models/blip/image_processing_blip_fast.py b/mindone/transformers/models/blip/image_processing_blip_fast.py
index 9f2811baf3..5d3c4ac390 100644
--- a/mindone/transformers/models/blip/image_processing_blip_fast.py
+++ b/mindone/transformers/models/blip/image_processing_blip_fast.py
@@ -17,16 +17,13 @@
# limitations under the License.
"""Fast Image processor class for BLIP."""
-from transformers.utils import add_start_docstrings
+from transformers.utils import auto_docstring
-from ...image_processing_utils_fast import BASE_IMAGE_PROCESSOR_FAST_DOCSTRING, BaseImageProcessorFast
+from ...image_processing_utils_fast import BaseImageProcessorFast
from ...image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD, PILImageResampling
-@add_start_docstrings(
- "Constructs a fast BLIP image processor.",
- BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
-)
+@auto_docstring
class BlipImageProcessorFast(BaseImageProcessorFast):
# To be checked against the slow image processor
# None values left after checking can be removed
diff --git a/mindone/transformers/models/chameleon/processing_chameleon.py b/mindone/transformers/models/chameleon/processing_chameleon.py
index 120afed1f1..1b9394afbd 100644
--- a/mindone/transformers/models/chameleon/processing_chameleon.py
+++ b/mindone/transformers/models/chameleon/processing_chameleon.py
@@ -20,13 +20,12 @@
"""
from typing import List, Optional, Union
+import numpy as np
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
-import mindspore as ms
-
from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput
-from ...processing_utils import ProcessingKwargs, ProcessorMixin, TextKwargs, Unpack, _validate_images_text_input_order
+from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, TextKwargs, Unpack
class ChameleonTextKwargs(TextKwargs, total=False):
@@ -39,6 +38,7 @@ class ChameleonProcessorKwargs(ProcessingKwargs, total=False):
"text_kwargs": {
"padding": False,
"return_for_text_completion": False,
+ "return_mm_token_type_ids": False,
},
"common_kwargs": {
"return_tensors": "ms",
@@ -67,16 +67,21 @@ class ChameleonProcessor(ProcessorMixin):
attributes = ["image_processor", "tokenizer"]
tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
- valid_kwargs = ["image_seq_length", "image_token"]
+
image_processor_class = "ChameleonImageProcessor"
    def __init__(self, image_processor, tokenizer, image_seq_length: int = 1024, image_token: str = "<image>"):
self.image_seq_length = image_seq_length
self.image_token = tokenizer.image_token if hasattr(tokenizer, "image_token") else image_token
self.image_start_token = (
            tokenizer.boi_token if hasattr(tokenizer, "boi_token") else "<racm3:break>"
) # fixed tokens for start and end, so can hardcode
        self.image_end_token = tokenizer.eoi_token if hasattr(tokenizer, "eoi_token") else "<eoss>"
+ self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
+ self.image_start_token_id = tokenizer.convert_tokens_to_ids(self.image_start_token)
+ self.image_end_token_id = tokenizer.convert_tokens_to_ids(self.image_end_token)
+ self.image_ids = [self.image_token_id, self.image_start_token_id, self.image_end_token_id]
super().__init__(image_processor, tokenizer)
@@ -92,7 +97,7 @@ def __call__(
Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode
        the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
- CLIPImageProcessor's [`~CLIPImageProcessor.__call__`] if `images` is not `None`. Please refer to the doctsring
+ CLIPImageProcessor's [`~CLIPImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
of the above two methods for more information.
Args:
@@ -120,8 +125,7 @@ def __call__(
`None`).
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
"""
- # check if images and text inputs are reversed for BC
- images, text = _validate_images_text_input_order(images, text)
+
if isinstance(text, str):
text = [text]
elif not isinstance(text, list) and not isinstance(text[0], str):
@@ -145,15 +149,45 @@ def __call__(
sample += self.tokenizer.sep_token # special Chameleon treatment to add sep for chat mode
prompt_strings.append(sample)
- output_kwargs["text_kwargs"].pop("return_tensors", None)
- data = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"], return_tensors="np")
- for k, v in data.items():
- data[k] = ms.tensor(v)
-
+ image_inputs = {}
if images is not None:
- data["pixel_values"] = self.image_processor(images, **output_kwargs["images_kwargs"])["pixel_values"]
+ image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
+
+ return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
+ return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
+ text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"], return_tensors=None)
+ self._check_special_mm_tokens(prompt_strings, text_inputs, modalities=["image"])
+
+ if return_mm_token_type_ids:
+ array_ids = np.array(text_inputs["input_ids"])
+ mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
+ mm_token_type_ids[np.isin(array_ids, self.image_ids)] = 1
+ text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
+
+ return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)
+
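The `return_mm_token_type_ids` branch above simply flags every image-related id (the image placeholder plus BOI/EOI) with a 1 via `np.isin`. A tiny NumPy illustration with made-up token ids standing in for `self.image_ids`:

```python
import numpy as np

image_ids = [9003, 9001, 9002]  # hypothetical [image_token_id, image_start_token_id, image_end_token_id]
input_ids = np.array([[1, 9001, 9003, 9003, 9002, 42]])

mm_token_type_ids = np.zeros_like(input_ids)
mm_token_type_ids[np.isin(input_ids, image_ids)] = 1
print(mm_token_type_ids.tolist())  # [[0, 1, 1, 1, 1, 0]]
```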
+ def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
+ """
+ Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
+
+ Args:
+ image_sizes (`list[list[int]]`, *optional*):
+ The input sizes formatted as (height, width) per each image.
+
+ Returns:
+ `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
+ input modalities, along with other useful data.
+ """
+
+ vision_data = {}
+ if image_sizes is not None:
+ # add 2 for BOI and EOI tokens
+ num_image_tokens = [self.image_seq_length + 2] * len(image_sizes)
+ num_image_patches = [1] * len(image_sizes)
+
+ vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
- return BatchFeature(data=data, tensor_type=output_kwargs["common_kwargs"]["return_tensors"])
+ return MultiModalData(**vision_data)
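Because Chameleon uses a fixed placeholder budget per image, `_get_num_multimodal_tokens` reduces to `image_seq_length + 2` tokens (the extra two being BOI and EOI) and a single patch per image. A quick check of that arithmetic under the default `image_seq_length=1024` (the image sizes below are arbitrary):

```python
image_seq_length = 1024
image_sizes = [(512, 512), (384, 640)]  # two images; Chameleon ignores the actual sizes

num_image_tokens = [image_seq_length + 2] * len(image_sizes)
num_image_patches = [1] * len(image_sizes)
print(num_image_tokens, num_image_patches)  # [1026, 1026] [1, 1]
```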
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
def batch_decode(self, *args, **kwargs):
diff --git a/mindone/transformers/models/clvp/modeling_clvp.py b/mindone/transformers/models/clvp/modeling_clvp.py
index 4f16d987f2..b9921b0efd 100644
--- a/mindone/transformers/models/clvp/modeling_clvp.py
+++ b/mindone/transformers/models/clvp/modeling_clvp.py
@@ -21,27 +21,22 @@
import copy
import math
from dataclasses import dataclass
-from typing import Dict, Optional, Tuple, Union
+from typing import Callable, Optional, Union
-from transformers import ClvpConfig, ClvpDecoderConfig, ClvpEncoderConfig
from transformers.generation import GenerationConfig
-from transformers.utils import (
- ModelOutput,
- add_start_docstrings,
- add_start_docstrings_to_model_forward,
- logging,
- replace_return_docstrings,
-)
+from transformers.models.clvp.configuration_clvp import ClvpConfig, ClvpDecoderConfig, ClvpEncoderConfig
+from transformers.utils import auto_docstring
import mindspore as ms
-from mindspore import mint, nn, ops
+from mindspore import Parameter, mint, nn, ops
from mindspore.mint.nn import CrossEntropyLoss
from mindone.models.utils import normal_, ones_, zeros_
-from ...activations import ACT2FN
+from ...activations import ACT2FN, get_activation
+from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
-from ...mindspore_utils import Conv1D, isin_mps_friendly
+from ...mindspore_utils import Conv1D
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
from ...modeling_outputs import (
BaseModelOutput,
@@ -49,12 +44,11 @@
BaseModelOutputWithPooling,
CausalLMOutputWithCrossAttentions,
)
-from ...modeling_utils import PreTrainedModel, SequenceSummary
+from ...modeling_utils import PreTrainedModel
+from ...utils import ModelOutput, logging
logger = logging.get_logger(__name__)
-_CHECKPOINT_FOR_DOC = "susnato/clvp_dev"
-
# Copied from transformers.models.clip.modeling_clip.contrastive_loss
def contrastive_loss(logits: ms.Tensor) -> ms.Tensor:
@@ -131,14 +125,8 @@ def _pad_extra_bos_eos_tokens(
modified_input_ids = mint.zeros((input_ids.shape[0], input_ids.shape[1] + 1), dtype=input_ids.dtype)
for i, each_input_id in enumerate(input_ids):
# locate where the valid tokens end and then add the eos token
- if isin_mps_friendly(each_input_id, ms.tensor(pad_token_id)).sum():
- pos = mint.where(each_input_id == pad_token_id)[0].min()
- modified_input_ids[i] = mint.concat(
- [each_input_id[:pos], ms.tensor([eos_token_id], dtype=input_ids.dtype), each_input_id[pos:]]
- )
- else:
- # if there are no pad tokens present, then add eos to the end
- modified_input_ids[i] = mint.nn.functional.pad(each_input_id, (0, 1), value=eos_token_id)
+        # append the eos token at the end of the sequence
+ modified_input_ids[i] = mint.nn.functional.pad(each_input_id, (0, 1), value=eos_token_id)
attention_mask = (
mint.nn.functional.pad(attention_mask, (1, 0), value=1) if attention_mask is not None else attention_mask
)
@@ -147,78 +135,72 @@ def _pad_extra_bos_eos_tokens(
@dataclass
-class ClvpEncoderOutput(ModelOutput):
- """
+@auto_docstring(
+ custom_intro="""
Base class for CLVP encoder's outputs that contains a pooling of the last hidden states as well as a projection
output (a linear layer on top of the pooled output).
-
- Args:
- embeds (`ms.Tensor` of shape `(batch_size, output_dim)`, *optional*, returned when model is initialized with `with_projection=True`):
- The embeddings obtained by applying the projection layer to the pooler_output.
- last_hidden_state (`ms.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
- The hidden state of the last layer of the model.
- pooler_output (`ms.Tensor` of shape `(batch_size, hidden_size)`):
- Pooled output of the `last_hidden_state`.
- hidden_states (`tuple(ms.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
- Tuple of `ms.Tensor` (one for the output of the embeddings, if the model has an embedding layer, +
- one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of
- the model at the output of each layer plus the optional initial embedding outputs.
- attentions (`tuple(ms.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
- Tuple of `ms.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
- sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
- the self-attention heads.
+ """
+)
+class ClvpEncoderOutput(ModelOutput):
+ r"""
+ embeds (`ms.Tensor` of shape `(batch_size, output_dim)`, *optional*, returned when model is initialized with `with_projection=True`):
+ The embeddings obtained by applying the projection layer to the pooler_output.
+ last_hidden_state (`ms.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ The hidden state of the last layer of the model.
+ pooler_output (`ms.Tensor` of shape `(batch_size, hidden_size)`):
+ Pooled output of the `last_hidden_state`.
"""
embeds: Optional[ms.Tensor] = None
- last_hidden_state: ms.Tensor = None
+ last_hidden_state: Optional[ms.Tensor] = None
pooler_output: Optional[ms.Tensor] = None
- hidden_states: Optional[Tuple[ms.Tensor]] = None
- attentions: Optional[Tuple[ms.Tensor]] = None
+ hidden_states: Optional[tuple[ms.Tensor]] = None
+ attentions: Optional[tuple[ms.Tensor]] = None
@dataclass
+@auto_docstring
class ClvpOutput(ModelOutput):
- """
- Args:
- loss (`ms.Tensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
- Contrastive loss for speech-text similarity.
- speech_ids (`ms.Tensor`, *optional*):
- speech_ids (or speech candidates) generated by the `ClvpForCausalLM` model.
- logits_per_speech (`ms.Tensor` of shape `(speech_batch_size, text_batch_size)`):
- The scaled dot product scores between `speech_embeds` and `text_embeds`. This represents the speech-text
- similarity scores.
- logits_per_text (`ms.Tensor` of shape `(text_batch_size, speech_batch_size)`):
- The scaled dot product scores between `text_embeds` and `speech_embeds`. This represents the text-speech
- similarity scores.
- text_embeds (`ms.Tensor` of shape `(batch_size, output_dim`):
- The text embeddings obtained by applying the projection layer to the pooled output of the text encoder
- model.
- speech_embeds (`ms.Tensor` of shape `(batch_size, output_dim`):
- The speech embeddings obtained by applying the projection layer to the pooled output of the speech encoder
- model.
- text_model_output (`BaseModelOutputWithPooling`):
- The pooled output of the `last_hidden_state` of the text encoder Model.
- speech_model_output (`BaseModelOutputWithPooling`):
- The pooled output of the `last_hidden_state` of the speech encoder Model.
- decoder_hidden_states (`ms.Tensor`, *optional*):
- The hidden states of the decoder model.
- text_encoder_hidden_states (`ms.Tensor`, *optional*):
- The hidden states of the text encoder model.
- speech_encoder_hidden_states (`ms.Tensor`, *optional*):
- The hidden states of the speech encoder model.
+ r"""
+ loss (`ms.Tensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
+ Contrastive loss for speech-text similarity.
+ speech_ids (`ms.Tensor`, *optional*):
+ speech_ids (or speech candidates) generated by the `ClvpForCausalLM` model.
+ logits_per_speech (`ms.Tensor` of shape `(speech_batch_size, text_batch_size)`):
+ The scaled dot product scores between `speech_embeds` and `text_embeds`. This represents the speech-text
+ similarity scores.
+ logits_per_text (`ms.Tensor` of shape `(text_batch_size, speech_batch_size)`):
+ The scaled dot product scores between `text_embeds` and `speech_embeds`. This represents the text-speech
+ similarity scores.
+ text_embeds (`ms.Tensor` of shape `(batch_size, output_dim`):
+ The text embeddings obtained by applying the projection layer to the pooled output of the text encoder
+ model.
+ speech_embeds (`ms.Tensor` of shape `(batch_size, output_dim`):
+ The speech embeddings obtained by applying the projection layer to the pooled output of the speech encoder
+ model.
+ text_model_output (`BaseModelOutputWithPooling`):
+ The pooled output of the `last_hidden_state` of the text encoder Model.
+ speech_model_output (`BaseModelOutputWithPooling`):
+ The pooled output of the `last_hidden_state` of the speech encoder Model.
+ decoder_hidden_states (`ms.Tensor`, *optional*):
+ The hidden states of the decoder model.
+ text_encoder_hidden_states (`ms.Tensor`, *optional*):
+ The hidden states of the text encoder model.
+ speech_encoder_hidden_states (`ms.Tensor`, *optional*):
+ The hidden states of the speech encoder model.
"""
loss: Optional[ms.Tensor] = None
speech_ids: Optional[ms.Tensor] = None
- logits_per_speech: ms.Tensor = None
- logits_per_text: ms.Tensor = None
- text_embeds: ms.Tensor = None
- speech_embeds: ms.Tensor = None
+ logits_per_speech: Optional[ms.Tensor] = None
+ logits_per_text: Optional[ms.Tensor] = None
+ text_embeds: Optional[ms.Tensor] = None
+ speech_embeds: Optional[ms.Tensor] = None
text_model_output: BaseModelOutputWithPooling = None
speech_model_output: BaseModelOutputWithPooling = None
- decoder_hidden_states: ms.Tensor = None
- text_encoder_hidden_states: ms.Tensor = None
- speech_encoder_hidden_states: ms.Tensor = None
+ decoder_hidden_states: Optional[ms.Tensor] = None
+ text_encoder_hidden_states: Optional[ms.Tensor] = None
+ speech_encoder_hidden_states: Optional[ms.Tensor] = None
# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Clvp
@@ -228,7 +210,7 @@ def __init__(self, hidden_size, eps=1e-6):
ClvpRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
- self.weight = ms.Parameter(mint.ones(hidden_size))
+ self.weight = Parameter(mint.ones(hidden_size))
self.variance_epsilon = eps
def construct(self, hidden_states):
@@ -245,13 +227,13 @@ def extra_repr(self):
class ClvpRotaryPositionalEmbedding(nn.Cell):
"""
Rotary Position Embedding Class for CLVP. It was proposed in the paper 'ROFORMER: ENHANCED TRANSFORMER WITH ROTARY
- POSITION EMBEDDING', Please see https://arxiv.org/pdf/2104.09864v1.pdf .
+ POSITION EMBEDDING', Please see https://huggingface.co/papers/2104.09864v1.pdf .
"""
def __init__(self, config):
super().__init__()
dim = max(config.projection_dim // (config.num_attention_heads * 2), 32)
- inv_freq = 1.0 / (10000 ** (mint.arange(0, dim, 2, dtype=ms.int32).float() / dim))
+ inv_freq = 1.0 / (10000 ** (mint.arange(0, dim, 2, dtype=ms.int64).float() / dim))
self.register_buffer("inv_freq", inv_freq)
self.cached_sequence_length = None
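The buffer above is the standard RoFormer frequency table, `inv_freq[i] = 1 / 10000^(2i / dim)`; switching the `arange` dtype to `int64` does not change the values, since they are cast to float before the division. A small NumPy check of the resulting frequencies (`dim=32` is just the minimum used above):

```python
import numpy as np

dim = 32
inv_freq = 1.0 / (10000 ** (np.arange(0, dim, 2, dtype=np.int64).astype(np.float32) / dim))
print(inv_freq[:3])  # approx. [1.0, 0.562, 0.316] -- geometrically decaying frequencies
```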
@@ -277,7 +259,7 @@ class ClvpSelfAttention(nn.Cell):
Multi-headed attention to combine Absolute and Rotary Positional Embeddings into a single Attention module.
"""
- def __init__(self, config):
+ def __init__(self, config, layer_idx=None):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
@@ -290,6 +272,7 @@ def __init__(self, config):
)
self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout
+ self.layer_idx = layer_idx
if hasattr(config, "max_position_embeddings"):
max_positions = config.max_position_embeddings
@@ -302,9 +285,8 @@ def __init__(self, config):
self.q_proj = mint.nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_attention_bias)
self.out_proj = mint.nn.Linear(self.embed_dim, self.embed_dim)
- # Copied from transformers.models.clip.modeling_clip.CLIPAttention._shape
def _shape(self, tensor: ms.Tensor, seq_len: int, bsz: int):
- return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).swapaxes(1, 2).contiguous()
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
def construct(
self,
@@ -312,11 +294,12 @@ def construct(
rotary_pos_emb: Optional[ms.Tensor] = None,
attention_mask: Optional[ms.Tensor] = None,
position_ids: Optional[ms.Tensor] = None,
- past_key_value: Optional[Tuple[ms.Tensor]] = None,
+ past_key_value: Optional[Cache] = None,
use_cache: Optional[bool] = False,
head_mask: Optional[ms.Tensor] = None,
output_attentions: Optional[bool] = False,
- ) -> Tuple[ms.Tensor, Optional[ms.Tensor], Optional[Tuple[ms.Tensor]]]:
+ cache_position: Optional[ms.Tensor] = None,
+ ) -> tuple[ms.Tensor, Optional[ms.Tensor], Optional[tuple[ms.Tensor]]]:
# Raise error when position_ids is None but rotary_pos_emb is provided, because we need that when applying
# rotary_pos_emb to query and key states.
if rotary_pos_emb is not None and position_ids is None:
@@ -330,14 +313,9 @@ def construct(
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
if past_key_value is not None:
- past_key, past_value = past_key_value
- key_states = mint.cat((past_key, key_states), dim=-2)
- value_states = mint.cat((past_value, value_states), dim=-2)
-
- if use_cache is True:
- present = (key_states, value_states)
- else:
- present = None
+ key_states, value_states = past_key_value.update(
+ key_states, value_states, self.layer_idx, {"cache_position": cache_position}
+ )
if rotary_pos_emb is not None:
rotary_emb_dim = rotary_pos_emb.shape[-1]
@@ -366,7 +344,7 @@ def construct(
tgt_len = query_states.shape[2]
src_len = key_states.shape[2]
- attn_weights = mint.matmul(query_states, key_states.swapaxes(2, 3))
+ attn_weights = mint.matmul(query_states, key_states.transpose(2, 3))
if attention_mask is not None:
if attention_mask.shape != (bsz, 1, tgt_len, src_len):
@@ -390,15 +368,12 @@ def construct(
f" {attn_output.shape}"
)
- attn_output = attn_output.swapaxes(1, 2).contiguous()
+ attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
- if not output_attentions:
- attn_weights = None
-
- return attn_output, present, attn_weights
+ return attn_output, attn_weights
class ClvpGatedLinearUnit(nn.Cell):
@@ -428,7 +403,7 @@ def __init__(self, config):
self.fc1 = ClvpGatedLinearUnit(config)
self.fc2 = mint.nn.Linear(config.intermediate_size, config.hidden_size)
- self.dropout_layer = mint.nn.Dropout(config.dropout)
+ self.dropout_layer = nn.Dropout(config.dropout)
def construct(self, hidden_states: ms.Tensor) -> ms.Tensor:
hidden_states = self.fc1(hidden_states)
@@ -455,7 +430,7 @@ def construct(
attention_mask: ms.Tensor,
position_ids: ms.Tensor,
output_attentions: Optional[bool] = False,
- ) -> Tuple[ms.Tensor]:
+ ) -> tuple[ms.Tensor]:
"""
Args:
hidden_states (`ms.Tensor` of shape `(batch, seq_len, embed_dim)`):
@@ -474,7 +449,7 @@ def construct(
hidden_states = self.input_rmsnorm(hidden_states)
- attention_outputs = self.self_attn(
+ hidden_states, attn_weights = self.self_attn(
hidden_states=hidden_states,
rotary_pos_emb=rotary_pos_emb,
attention_mask=attention_mask,
@@ -482,8 +457,6 @@ def construct(
output_attentions=output_attentions,
)
- hidden_states = attention_outputs[0]
-
hidden_states = residual + hidden_states
residual = hidden_states
@@ -491,12 +464,105 @@ def construct(
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
- outputs = (hidden_states,)
+ return hidden_states, attn_weights
+
+
+# Copied from transformers.models.xlm.modeling_xlm.XLMSequenceSummary with XLM->Clvp
+class ClvpSequenceSummary(nn.Cell):
+ r"""
+ Compute a single vector summary of a sequence hidden states.
+
+ Args:
+ config ([`ClvpConfig`]):
+ The config used by the model. Relevant arguments in the config class of the model are (refer to the actual
+ config class of your model for the default values it uses):
+
+ - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are:
+
+ - `"last"` -- Take the last token hidden state (like XLNet)
+ - `"first"` -- Take the first token hidden state (like Bert)
+ - `"mean"` -- Take the mean of all tokens hidden states
+ - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2)
+ - `"attn"` -- Not implemented now, use multi-head attention
+
+ - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction.
+ - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes
+ (otherwise to `config.hidden_size`).
+ - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output,
+ another string or `None` will add no activation.
+ - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation.
+ - **summary_last_dropout** (`float`) -- Optional dropout probability after the projection and activation.
+ """
+
+ def __init__(self, config: ClvpConfig):
+ super().__init__()
+
+ self.summary_type = getattr(config, "summary_type", "last")
+ if self.summary_type == "attn":
+ # We should use a standard multi-head attention module with absolute positional embedding for that.
+ # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
+ # We can probably just use the multi-head attention module of PyTorch >=1.1.0
+ raise NotImplementedError
+
+ self.summary = nn.Identity()
+ if hasattr(config, "summary_use_proj") and config.summary_use_proj:
+ if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
+ num_classes = config.num_labels
+ else:
+ num_classes = config.hidden_size
+ self.summary = mint.nn.Linear(config.hidden_size, num_classes)
- if output_attentions:
- outputs += (attention_outputs[-1],)
+ activation_string = getattr(config, "summary_activation", None)
+ self.activation: Callable = get_activation(activation_string) if activation_string else nn.Identity()
- return outputs
+ self.first_dropout = nn.Identity()
+ if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:
+ self.first_dropout = nn.Dropout(config.summary_first_dropout)
+
+ self.last_dropout = nn.Identity()
+ if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0:
+ self.last_dropout = nn.Dropout(config.summary_last_dropout)
+
+ def construct(self, hidden_states: ms.Tensor, cls_index: Optional[ms.Tensor] = None) -> ms.Tensor:
+ """
+ Compute a single vector summary of a sequence hidden states.
+
+ Args:
+ hidden_states (`ms.Tensor` of shape `[batch_size, seq_len, hidden_size]`):
+ The hidden states of the last layer.
+ cls_index (`ms.Tensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*):
+ Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token.
+
+ Returns:
+ `ms.Tensor`: The summary of the sequence hidden states.
+ """
+ if self.summary_type == "last":
+ output = hidden_states[:, -1]
+ elif self.summary_type == "first":
+ output = hidden_states[:, 0]
+ elif self.summary_type == "mean":
+ output = hidden_states.mean(dim=1)
+ elif self.summary_type == "cls_index":
+ if cls_index is None:
+ cls_index = mint.full_like(
+ hidden_states[..., :1, :],
+ hidden_states.shape[-2] - 1,
+ dtype=ms.int64,
+ )
+ else:
+ cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
+ cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.shape[-1],))
+ # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
+ output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size)
+ elif self.summary_type == "attn":
+ raise NotImplementedError
+
+ output = self.first_dropout(output)
+ output = self.summary(output)
+ output = self.activation(output)
+ output = self.last_dropout(output)
+
+ return output
# Copied from transformers.models.gpt2.modeling_gpt2.GPT2MLP with GPT2->ClvpDecoderMLP
@@ -507,9 +573,9 @@ def __init__(self, intermediate_size, config):
self.c_fc = Conv1D(intermediate_size, embed_dim)
self.c_proj = Conv1D(embed_dim, intermediate_size)
self.act = ACT2FN[config.activation_function]
- self.dropout = mint.nn.Dropout(config.resid_pdrop)
+ self.dropout = nn.Dropout(config.resid_pdrop)
- def construct(self, hidden_states: Optional[Tuple[ms.Tensor]]) -> ms.Tensor:
+ def construct(self, hidden_states: Optional[tuple[ms.Tensor]]) -> ms.Tensor:
hidden_states = self.c_fc(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.c_proj(hidden_states)
@@ -518,27 +584,28 @@ def construct(self, hidden_states: Optional[Tuple[ms.Tensor]]) -> ms.Tensor:
class ClvpDecoderLayer(nn.Cell):
- def __init__(self, config):
+ def __init__(self, config, layer_idx=None):
super().__init__()
hidden_size = config.hidden_size
inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
self.input_layernorm = mint.nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
- self.attn = ClvpSelfAttention(config)
+ self.attn = ClvpSelfAttention(config, layer_idx=layer_idx)
self.post_attention_layernorm = mint.nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.mlp = ClvpDecoderMLP(inner_dim, config)
def construct(
self,
- hidden_states: Optional[Tuple[ms.Tensor]],
- past_key_value: Optional[Tuple[ms.Tensor]] = None,
+ hidden_states: Optional[tuple[ms.Tensor]],
+ past_key_value: Optional[Cache] = None,
attention_mask: Optional[ms.Tensor] = None,
position_ids: Optional[ms.Tensor] = None,
head_mask: Optional[ms.Tensor] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
- ) -> Union[Tuple[ms.Tensor], Optional[Tuple[ms.Tensor, Tuple[ms.Tensor, ...]]]]:
+ cache_position: Optional[ms.Tensor] = None,
+ ) -> Union[tuple[ms.Tensor], Optional[tuple[ms.Tensor, tuple[ms.Tensor, ...]]]]:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
attn_outputs = self.attn(
@@ -549,24 +616,19 @@ def construct(
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
+ cache_position=cache_position,
)
attn_output = attn_outputs[0]
- outputs = attn_outputs[1:]
# residual connection
hidden_states = attn_output + residual
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
- feed_construct_hidden_states = self.mlp(hidden_states)
+ feed_forward_hidden_states = self.mlp(hidden_states)
# residual connection
- hidden_states = residual + feed_construct_hidden_states
-
- if use_cache:
- outputs = (hidden_states,) + outputs
- else:
- outputs = (hidden_states,) + outputs[1:]
+ hidden_states = residual + feed_forward_hidden_states
- return outputs
+ return (hidden_states,) + attn_outputs[1:]
class ClvpConditioningEncoder(nn.Cell):
@@ -653,7 +715,7 @@ def construct(
# construct attention mask if not given
if attention_mask is None:
- attention_mask = mint.ones([batch_size, seq_length], dtype=ms.int32)
+ attention_mask = mint.ones([batch_size, seq_length], dtype=ms.int64)
# We add bos and eos input_ids in the modeling file instead of the tokenizer file to keep the logic simple
# This logic is specific to ClvpConditioningEncoder and not used by other modules.
@@ -669,24 +731,27 @@ def construct(
position_embeds = self.text_position_embedding(position_ids)
text_embeds = inputs_embeds + position_embeds
- # process each log-mel spectrogram into a single vector
- mel_spec = self.mel_conv(input_features)
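+ # Gradient checkpointing is not wired up in this MindSpore port yet, so fail loudly when it is requested during training.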
+ if self.gradient_checkpointing and self.training:
+ raise NotImplementedError
+ else:
+ # process each log-mel spectrogram into a single vector
+ mel_spec = self.mel_conv(input_features)
- for i, mel_attn_block in enumerate(self.mel_attn_blocks):
- residual_mel_spec = mel_spec.swapaxes(1, 2)
+ for i, mel_attn_block in enumerate(self.mel_attn_blocks):
+ residual_mel_spec = mel_spec.transpose(1, 2)
- mel_spec = self.group_norms[i](mel_spec).swapaxes(1, 2)
- mel_spec = mel_attn_block(mel_spec)[0] + residual_mel_spec
- mel_spec = mel_spec.swapaxes(1, 2)
+ mel_spec = self.group_norms[i](mel_spec).transpose(1, 2)
+ mel_spec = mel_attn_block(mel_spec)[0] + residual_mel_spec
+ mel_spec = mel_spec.transpose(1, 2)
mel_spec = mel_spec[:, :, 0]
mel_spec = mel_spec.unsqueeze(1)
# repeat if there is either (1 text vs N audios) or (N texts vs 1 audio)
if text_embeds.shape[0] == 1 and mel_spec.shape[0] != 1:
- text_embeds = text_embeds.tile((mel_spec.shape[0], 1, 1))
+ text_embeds = text_embeds.repeat(mel_spec.shape[0], 1, 1)
elif text_embeds.shape[0] != 1 and mel_spec.shape[0] == 1:
- mel_spec = mel_spec.tile((text_embeds.shape[0], 1, 1))
+ mel_spec = mel_spec.repeat(text_embeds.shape[0], 1, 1)
# If there is N texts and M audios we will raise error since the number of text and audio must be same.
elif text_embeds.shape[0] != mel_spec.shape[0]:
raise ValueError(
@@ -697,19 +762,14 @@ def construct(
return mint.concat([mel_spec, text_embeds], dim=1)
+@auto_docstring
class ClvpPreTrainedModel(PreTrainedModel):
- """
- An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
- models.
- """
-
- config_class = ClvpConfig
+ config: ClvpConfig
base_model_prefix = "clvp"
supports_gradient_checkpointing = True
_skip_keys_device_placement = "past_key_values"
- _supports_dynamic_input = True
- def _init_weights(self, module):
+ def _init_weights(self, module: nn.Cell):
"""Initialize the weights"""
factor = self.config.initializer_factor
if isinstance(module, mint.nn.Embedding):
@@ -742,122 +802,6 @@ def _init_weights(self, module):
ones_(module.weight)
-CLVP_START_DOCSTRING = r"""
- This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
- library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
- etc.)
-
- This model is also a MindSpore [mindspore.mint.nn.Cell](https://www.mindspore.cn/docs/en/master/api_python/nn/mindspore.nn.Cell.html) subclass.
- Use it as a regular MindSpore Module and refer to the MindSpore documentation for all matter related to general usage
- and behavior.
-
- Parameters:
- config ([`ClvpConfig`]): Model configuration class with all the parameters of the model.
- Initializing with a config file does not load the weights associated with the model, only the
- configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
-"""
-
-
-CLVP_INPUTS_DOCSTRING = r"""
- Args:
- input_ids (`ms.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
- Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
- it.
-
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
- [`PreTrainedTokenizer.__call__`] for details.
-
- [What are input IDs?](../glossary#input-ids)
- input_features (`ms.Tensor` of shape `(batch_size, feature_size, time_dim)`):
- Indicates log mel-spectrogram representations for audio returned by [`ClvpFeatureExtractor`].
- conditioning_encoder_inputs_embeds (`ms.Tensor`, *optional*):
- inputs_embeds for `ClvpConditioningEncoder`. Can be used in place of `input_ids`.
- text_encoder_inputs_embeds (`ms.Tensor`, *optional*):
- inputs_embeds for the text encoder model passed in place of `input_ids`.
- attention_mask (`ms.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
- Mask to avoid performing attention on padding text token indices. Mask values selected in `[0, 1]`:
-
- - 1 for tokens that are **not masked**,
- - 0 for tokens that are **masked**.
-
- [What are attention masks?](../glossary#attention-mask)
- return_loss (`bool`, *optional*):
- Whether or not to return the contrastive loss.
- output_attentions (`bool`, *optional*):
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
- tensors for more detail.
- output_hidden_states (`bool`, *optional*):
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
- more detail.
- return_dict (`bool`, *optional*):
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
-"""
-
-
-CLVP_DECODER_INPUTS_DOCSTRING = r"""
- Args:
- input_ids (`ms.Tensor` of shape `(batch_size, input_ids_length)`):
- Indices of input sequence tokens in the vocabulary.
-
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
- [`PreTrainedTokenizer.__call__`] for details.
-
- [What are input IDs?](../glossary#input-ids)
- past_key_values (`Tuple[Tuple[ms.Tensor]]` of length `config.n_layers`):
- Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
- `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
- their past given to this model should not be passed as `input_ids` as they have already been computed.
- attention_mask (`ms.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
-
- - 1 for tokens that are **not masked**,
- - 0 for tokens that are **masked**.
-
- If `past_key_values` is used, `attention_mask` needs to contain the masking strategy that was used for
- `past_key_values`. In other words, the `attention_mask` always has to have the length:
- `len(past_key_values) + len(input_ids)`
-
- [What are attention masks?](../glossary#attention-mask)
- token_type_ids (`ms.Tensor` of shape `(batch_size, input_ids_length)`, *optional*):
- Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
- 1]`:
-
- - 0 corresponds to a *sentence A* token,
- - 1 corresponds to a *sentence B* token.
-
- [What are token type IDs?](../glossary#token-type-ids)
- position_ids (`ms.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
- Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
- config.max_position_embeddings - 1]`.
-
- [What are position IDs?](../glossary#position-ids)
- head_mask (`ms.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
- Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
-
- - 1 indicates the head is **not masked**,
- - 0 indicates the head is **masked**.
-
- inputs_embeds (`ms.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
- Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
- is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
- model's internal embedding lookup matrix.
-
- If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
- `past_key_values`).
- use_cache (`bool`, *optional*):
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
- `past_key_values`).
- output_attentions (`bool`, *optional*):
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
- tensors for more detail.
- output_hidden_states (`bool`, *optional*):
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
- more detail.
- return_dict (`bool`, *optional*):
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
-"""
-
-
class ClvpEncoder(ClvpPreTrainedModel):
"""
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
@@ -875,7 +819,7 @@ def __init__(self, config: ClvpConfig):
self.rotary_pos_emb = ClvpRotaryPositionalEmbedding(config) if config.use_rotary_embedding else None
self.layers = nn.CellList([ClvpEncoderLayer(config) for _ in range(config.num_hidden_layers)])
- self.sequence_summary = SequenceSummary(config)
+ self.sequence_summary = ClvpSequenceSummary(config)
self.final_layer_norm = mint.nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.projection = mint.nn.Linear(config.hidden_size, config.projection_dim, bias=False)
@@ -899,7 +843,7 @@ def construct(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
- ) -> Union[Tuple, BaseModelOutput]:
+ ) -> Union[tuple, BaseModelOutput]:
r"""
Args:
input_ids (`ms.Tensor` of shape `(batch_size, input_ids_length)`, *optional*):
@@ -953,7 +897,7 @@ def construct(
attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
if position_ids is None:
- position_ids = mint.arange(input_shape[1], dtype=ms.int32)
+ position_ids = mint.arange(input_shape[1], dtype=ms.int64)
position_ids = position_ids.unsqueeze(0)
encoder_states = () if output_hidden_states else None
@@ -965,14 +909,16 @@ def construct(
for idx, encoder_layer in enumerate(self.layers):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
-
- layer_outputs = encoder_layer(
- hidden_states,
- rotary_pos_emb,
- attention_mask,
- position_ids,
- output_attentions=output_attentions,
- )
+ if self.gradient_checkpointing and self.training:
+ raise NotImplementedError
+ else:
+ layer_outputs = encoder_layer(
+ hidden_states,
+ rotary_pos_emb,
+ attention_mask,
+ position_ids,
+ output_attentions=output_attentions,
+ )
hidden_states = layer_outputs[0]
@@ -1018,8 +964,10 @@ def __init__(self, config):
self.input_embeds_layer = mint.nn.Embedding(self.config.vocab_size, self.config.hidden_size)
self.position_embeds_layer = mint.nn.Embedding(self.config.max_position_embeddings, self.config.hidden_size)
- self.drop = mint.nn.Dropout(self.config.embd_pdrop)
- self.layers = nn.CellList([ClvpDecoderLayer(self.config) for _ in range(self.config.num_hidden_layers)])
+ self.drop = nn.Dropout(self.config.embd_pdrop)
+ self.layers = nn.CellList(
+ [ClvpDecoderLayer(self.config, layer_idx=i) for i in range(self.config.num_hidden_layers)]
+ )
self.layer_norm = mint.nn.LayerNorm(self.config.hidden_size, eps=self.config.layer_norm_epsilon)
self.gradient_checkpointing = False
@@ -1040,7 +988,7 @@ def _prune_heads(self, heads_to_prune):
for layer, heads in heads_to_prune.items():
self.layers[layer].attn.prune_heads(heads)
- @add_start_docstrings_to_model_forward(CLVP_DECODER_INPUTS_DOCSTRING)
+ @auto_docstring
def construct(
self,
input_ids: Optional[ms.Tensor] = None,
@@ -1048,13 +996,14 @@ def construct(
token_type_ids: Optional[ms.Tensor] = None,
position_ids: Optional[ms.Tensor] = None,
head_mask: Optional[ms.Tensor] = None,
- past_key_values: Optional[Tuple[Tuple[ms.Tensor]]] = None,
+ past_key_values: Optional[tuple[tuple[ms.Tensor]]] = None,
inputs_embeds: Optional[ms.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
- ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
+ cache_position: Optional[ms.Tensor] = None,
+ ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -1078,13 +1027,26 @@ def construct(
if token_type_ids is not None:
token_type_ids = token_type_ids.view(-1, input_shape[-1])
- if past_key_values is None:
- past_key_values_length = 0
- past_key_values = tuple([None] * len(self.layers))
- else:
- past_key_values_length = past_key_values[0][0].shape[-2]
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ return_legacy_cache = False
+ if use_cache and not isinstance(past_key_values, Cache):
+ logger.warning_once(
+ "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
+ "You should pass an instance of `DynamicCache` instead, e.g. "
+ "`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`."
+ )
+ return_legacy_cache = True
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+
+ past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
if position_ids is None:
- position_ids = mint.arange(past_key_values_length, input_shape[-1] + past_key_values_length, dtype=ms.int32)
+ position_ids = mint.arange(past_key_values_length, input_shape[-1] + past_key_values_length, dtype=ms.int64)
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
if inputs_embeds is None:
@@ -1112,39 +1074,33 @@ def construct(
output_shape = (-1,) + input_shape[1:] + (hidden_states.shape[-1],)
- if self.gradient_checkpointing and self.training:
- if use_cache:
- logger.warning_once(
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
- )
- use_cache = False
-
- presents = () if use_cache else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
all_hidden_states = () if output_hidden_states else None
- for i, (block, past_key_value) in enumerate(zip(self.layers, past_key_values)):
+ for i, block in enumerate(self.layers):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
- outputs = block(
- hidden_states,
- past_key_value=past_key_value,
- attention_mask=attention_mask,
- position_ids=position_ids,
- head_mask=head_mask[i],
- use_cache=use_cache,
- output_attentions=output_attentions,
- )
+ if self.gradient_checkpointing and self.training:
+ raise NotImplementedError
+ else:
+ outputs = block(
+ hidden_states,
+ past_key_value=past_key_values,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ head_mask=head_mask[i],
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ cache_position=cache_position,
+ )
hidden_states = outputs[0]
- if use_cache is True:
- presents = presents + (outputs[1],)
if output_attentions:
- all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
+ all_self_attentions = all_self_attentions + (outputs[1],)
if self.config.add_cross_attention:
- all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
+ all_cross_attentions = all_cross_attentions + (outputs[2],)
hidden_states = self.layer_norm(hidden_states)
@@ -1154,26 +1110,26 @@ def construct(
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
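+ # Callers that passed a legacy tuple cache get a tuple back; convert the DynamicCache to the legacy format before returning.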
+ if return_legacy_cache:
+ past_key_values = past_key_values.to_legacy_cache()
+
if not return_dict:
return tuple(
v
- for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
+ for v in [hidden_states, past_key_values, all_hidden_states, all_self_attentions, all_cross_attentions]
if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
- past_key_values=presents,
+ past_key_values=past_key_values,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
cross_attentions=all_cross_attentions,
)
-@add_start_docstrings(
- "The bare Clvp decoder model outputting raw hidden-states without any specific head on top.",
- CLVP_START_DOCSTRING,
-)
+@auto_docstring
class ClvpModel(ClvpPreTrainedModel):
def __init__(self, config: ClvpDecoderConfig):
super().__init__(config)
@@ -1192,7 +1148,7 @@ def set_input_embeddings(self, value):
def get_decoder(self):
return self.decoder
- @add_start_docstrings_to_model_forward(CLVP_DECODER_INPUTS_DOCSTRING)
+ @auto_docstring
def construct(
self,
input_ids: Optional[ms.Tensor] = None,
@@ -1200,13 +1156,14 @@ def construct(
token_type_ids: Optional[ms.Tensor] = None,
position_ids: Optional[ms.Tensor] = None,
head_mask: Optional[ms.Tensor] = None,
- past_key_values: Optional[Tuple[Tuple[ms.Tensor]]] = None,
+ past_key_values: Optional[tuple[tuple[ms.Tensor]]] = None,
inputs_embeds: Optional[ms.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
- ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
+ cache_position: Optional[ms.Tensor] = None,
+ ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -1227,6 +1184,7 @@ def construct(
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
+ cache_position=cache_position,
)
if not return_dict:
@@ -1241,9 +1199,10 @@ def construct(
)
-@add_start_docstrings(
- "The CLVP decoder model with a language modelling head on top.",
- CLVP_START_DOCSTRING,
+@auto_docstring(
+ custom_intro="""
+ The CLVP decoder model with a language modelling head on top.
+ """
)
class ClvpForCausalLM(ClvpPreTrainedModel, GenerationMixin):
def __init__(self, config):
@@ -1258,6 +1217,11 @@ def __init__(self, config):
# Initialize weights and apply final processing
self.post_init()
+ def get_output_embeddings(self):
+ # NOTE: get_output_embeddings() must return None to prevent accidental weight tying.
+ # See e.g. https://github.com/huggingface/transformers/pull/39339#discussion_r2219126400
+ return None
+
def get_input_embeddings(self):
return self.model.decoder.input_embeds_layer
@@ -1268,8 +1232,8 @@ def _prepare_model_inputs(
self,
inputs: Optional[ms.Tensor] = None,
bos_token_id: Optional[int] = None,
- model_kwargs: Optional[Dict[str, ms.Tensor]] = None,
- ) -> Tuple[ms.Tensor, Optional[str], Dict[str, ms.Tensor]]:
+ model_kwargs: Optional[dict[str, ms.Tensor]] = None,
+ ) -> tuple[ms.Tensor, Optional[str], dict[str, ms.Tensor]]:
"""
This function extracts the model-specific `inputs` for generation.
"""
@@ -1293,18 +1257,18 @@ def _prepare_model_inputs(
inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"
# Check if conditioning_embeds are provided or not, if yes then concatenate the bos_token_id at the end of the conditioning_embeds.
- # Then we must subtract the positional_ids because during the construct pass it will be added anyways, so we must cancel them out here.
+ # Then we must subtract the positional embeddings because they will be added again during the forward pass, so we must cancel them out here.
conditioning_embeds = model_kwargs.get("conditioning_embeds", None)
if conditioning_embeds is not None:
mel_start_token_embedding = self.model.decoder.input_embeds_layer(
- mint.full(
+ ops.full(
(conditioning_embeds.shape[0], 1),
fill_value=self.config.bos_token_id,
)
)
mel_start_token_embedding += self.model.decoder.position_embeds_layer(
- mint.full((conditioning_embeds.shape[0], 1), fill_value=0)
+ ops.full((conditioning_embeds.shape[0], 1), fill_value=0)
)
conditioning_embeds = mint.concat([conditioning_embeds, mel_start_token_embedding], dim=1)
@@ -1312,14 +1276,12 @@ def _prepare_model_inputs(
if hasattr(model_kwargs, "attention_mask"):
position_ids = model_kwargs["attention_mask"].long().cumsum(-1) - 1
else:
- position_ids = ops.range(
- 0, conditioning_embeds.shape[1], step=1
- ) # NOTE: usage different from torch.range
- position_ids = position_ids.unsqueeze(0).tile((conditioning_embeds.shape[0], 1))
+ position_ids = mint.arange(0, conditioning_embeds.shape[1], dtype=ms.int64)
+ position_ids = position_ids.unsqueeze(0).repeat(conditioning_embeds.shape[0], 1)
model_kwargs["inputs_embeds"] = conditioning_embeds - self.model.decoder.position_embeds_layer(position_ids)
model_kwargs["input_ids"] = (
- mint.ones((model_kwargs["inputs_embeds"].shape[0], 1), dtype=ms.int32) * self.config.bos_token_id
+ mint.ones((model_kwargs["inputs_embeds"].shape[0], 1), dtype=ms.int64) * self.config.bos_token_id
)
return model_kwargs["inputs_embeds"], "inputs_embeds", model_kwargs
@@ -1328,63 +1290,35 @@ def _prepare_model_inputs(
return inputs, input_name, model_kwargs
def prepare_inputs_for_generation(
- self, input_ids, past_key_values=None, inputs_embeds=None, conditioning_embeds=None, **kwargs
+ self,
+ input_ids,
+ past_key_values=None,
+ inputs_embeds=None,
+ conditioning_embeds=None,
+ cache_position=None,
+ **kwargs,
):
# Overwritten: has `conditioning_embeds`-related logic
input_ids_length = input_ids.shape[-1]
- token_type_ids = kwargs.get("token_type_ids", None)
- # only last token for inputs_ids if past is defined in kwargs
- if past_key_values:
- past_length = past_key_values[0][0].shape[2]
-
- # Some generation methods already pass only the last input ID
- if input_ids.shape[1] > past_length:
- remove_prefix_length = past_length
- else:
- # Default to old behavior: keep only final ID
- remove_prefix_length = input_ids.shape[1] - 1
-
- input_ids = input_ids[:, remove_prefix_length:]
- if token_type_ids is not None:
- token_type_ids = token_type_ids[:, -input_ids.shape[1] :]
-
- attention_mask = kwargs.get("attention_mask", None)
- position_ids = kwargs.get("position_ids", None)
-
- if attention_mask is not None and position_ids is None:
- # create position_ids on the fly for batch generation
- position_ids = attention_mask.long().cumsum(-1) - 1
- position_ids.masked_fill_(attention_mask == 0, 1)
- if past_key_values:
- position_ids = position_ids[:, -1].unsqueeze(-1)
- else:
- position_ids = None
- if conditioning_embeds is not None and past_key_values is not None:
- position_ids = ms.tensor([input_ids_length], dtype=ms.int32)
-
- # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
- if inputs_embeds is not None and past_key_values is None:
- model_inputs = {"inputs_embeds": inputs_embeds}
- else:
- model_inputs = {"input_ids": input_ids}
-
- model_inputs.update(
- {
- "past_key_values": past_key_values,
- "use_cache": kwargs.get("use_cache"),
- "position_ids": position_ids,
- "token_type_ids": token_type_ids,
- }
+ model_inputs = super().prepare_inputs_for_generation(
+ input_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ cache_position=cache_position,
+ **kwargs,
)
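+ # When generating from conditioning embeddings past the first step (cache already populated), pin position_ids to the current sequence length.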
+ if conditioning_embeds is not None and cache_position[0] != 0:
+ model_inputs["position_ids"] = ms.Tensor([input_ids_length], dtype=ms.int64)
+
return model_inputs
- @add_start_docstrings_to_model_forward(CLVP_DECODER_INPUTS_DOCSTRING)
+ @auto_docstring
def construct(
self,
input_ids: Optional[ms.Tensor] = None,
- past_key_values: Optional[Tuple[Tuple[ms.Tensor]]] = None,
+ past_key_values: Optional[tuple[tuple[ms.Tensor]]] = None,
attention_mask: Optional[ms.Tensor] = None,
token_type_ids: Optional[ms.Tensor] = None,
position_ids: Optional[ms.Tensor] = None,
@@ -1395,7 +1329,8 @@ def construct(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
- ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
+ cache_position: Optional[ms.Tensor] = None,
+ ) -> Union[tuple, CausalLMOutputWithCrossAttentions]:
r"""
labels (`ms.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
@@ -1422,6 +1357,7 @@ def construct(
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
+ cache_position=cache_position,
)
hidden_states = outputs[0]
@@ -1451,27 +1387,14 @@ def construct(
cross_attentions=outputs.cross_attentions,
)
- @staticmethod
- # Copied from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel._reorder_cache
- def _reorder_cache(past_key_values: Tuple[Tuple[ms.Tensor]], beam_idx: ms.Tensor) -> Tuple[Tuple[ms.Tensor]]:
- """
- This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
- [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
- beam_idx at every generation step.
- """
- return tuple(
- tuple(past_state.index_select(0, beam_idx) for past_state in layer_past) for layer_past in past_key_values
- )
-
-@add_start_docstrings(
- "The composite CLVP model with a text encoder, speech encoder and speech decoder model."
- "The speech decoder model generates the speech_ids from the text and the text encoder and speech encoder works"
- "together to filter out the best speech_ids.",
- CLVP_START_DOCSTRING,
+@auto_docstring(
+ custom_intro="""
+ The composite CLVP model with a text encoder, speech encoder and speech decoder model.
+ """
)
class ClvpModelForConditionalGeneration(ClvpPreTrainedModel, GenerationMixin):
- config_class = ClvpConfig
+ config: ClvpConfig
def __init__(self, config: ClvpConfig):
super().__init__(config)
@@ -1501,7 +1424,7 @@ def __init__(self, config: ClvpConfig):
self.text_encoder_model = ClvpEncoder(config.text_config)
self.speech_encoder_model = ClvpEncoder(config.speech_config)
- self.logit_scale = ms.Parameter(ms.tensor(self.config.logit_scale_init_value))
+ self.logit_scale = Parameter(ms.Tensor(self.config.logit_scale_init_value))
# Initialize weights and apply final processing
self.post_init()
@@ -1532,7 +1455,7 @@ def fix_speech_decoder_output(self, speech_ids: ms.Tensor) -> ms.Tensor:
stm = each_seq_stop_token_index.argmax()
speech_ids[i, stm:] = decoder_fixing_codes[0]
if stm - 3 < speech_ids.shape[1]:
- speech_ids[i, -3:] = ms.tensor([decoder_fixing_codes[1:]], dtype=ms.int32)
+ speech_ids[i, -3:] = ms.tensor([decoder_fixing_codes[1:]], dtype=ms.int64)
return speech_ids
@@ -1570,8 +1493,8 @@ def get_text_features(
Examples:
```python
- >>> from mindone.transformers import ClvpProcessor, ClvpModelForConditionalGeneration
>>> import mindspore as ms
+ >>> from mindone.transformers import ClvpProcessor, ClvpModelForConditionalGeneration
>>> # Define the Text
>>> text = "This is an example text."
@@ -1582,7 +1505,7 @@ def get_text_features(
>>> # Generate processor output and text embeds
>>> processor_output = processor(text=text, return_tensors="np")
- >>> text_embeds = model.get_text_features(input_ids=ms.Tensor(processor_output["input_ids"]))
+ >>> text_embeds = model.get_text_features(input_ids=ms.tensor(processor_output["input_ids"]))
```
"""
@@ -1616,9 +1539,6 @@ def get_speech_features(
input_ids (`ms.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Input text Tokens. Processed from the [`ClvpTokenizer`]. If speech_ids is not provided, then input_ids
and input_features will be used.
- input_features (`ms.Tensor` of shape `(batch_size, feature_size, time_dim)`, *optional*):
- Indicates log-melspectrogram representations for audio returned by [`ClvpFeatureExtractor`]. If
- speech_ids is not provided, then input_ids and input_features will be used.
conditioning_encoder_inputs_embeds (`ms.Tensor`, *optional*):
inputs_embeds for `ClvpConditioningEncoder`. Can be used in place of `input_ids`.
attention_mask (`ms.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1640,23 +1560,24 @@ def get_speech_features(
```python
>>> import datasets
- >>> from mindone.transformers import ClvpProcessor, ClvpModelForConditionalGeneration
>>> import mindspore as ms
+ >>> from mindone.transformers import ClvpProcessor, ClvpModelForConditionalGeneration
>>> # Define the Text and Load the Audio (We are taking an audio example from HuggingFace Hub using `datasets` library)
>>> text = "This is an example text."
>>> ds = datasets.load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> ds = ds.cast_column("audio", datasets.Audio(sampling_rate=22050))
- >>> _, audio, sr = ds.sort("id").select(range(1))[:1]["audio"][0].values()
+ >>> audio = ds.sort("id")["audio"][0]
+ >>> audio_sample, sr = audio["array"], audio["sampling_rate"]
>>> # Define processor and model
>>> processor = ClvpProcessor.from_pretrained("susnato/clvp_dev")
>>> model = ClvpModelForConditionalGeneration.from_pretrained("susnato/clvp_dev")
>>> # Generate processor output and model output
- >>> processor_output = processor(raw_speech=audio, sampling_rate=sr, text=text, return_tensors="np")
+ >>> processor_output = processor(raw_speech=audio_sample, sampling_rate=sr, text=text, return_tensors="np")
>>> speech_embeds = model.get_speech_features(
- ... input_ids=ms.Tensor(processor_output["input_ids"]), input_features=ms.Tensor(processor_output["input_features"])
+ ... input_ids=ms.tensor(processor_output["input_ids"]), input_features=ms.tensor(processor_output["input_features"])
... )
```
"""
@@ -1692,12 +1613,11 @@ def get_speech_features(
return outputs[0]
- @add_start_docstrings_to_model_forward(CLVP_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=ClvpOutput, config_class=ClvpConfig)
+ @auto_docstring
def construct(
self,
- input_ids: ms.Tensor = None,
- input_features: ms.Tensor = None,
+ input_ids: Optional[ms.Tensor] = None,
+ input_features: Optional[ms.Tensor] = None,
conditioning_encoder_inputs_embeds: Optional[ms.Tensor] = None,
text_encoder_inputs_embeds: Optional[ms.Tensor] = None,
attention_mask: Optional[ms.Tensor] = None,
@@ -1705,35 +1625,39 @@ def construct(
output_hidden_states: Optional[bool] = None,
output_attentions: Optional[bool] = False,
return_dict: Optional[bool] = None,
- ) -> Union[Tuple, ClvpOutput]:
+ cache_position: Optional[ms.Tensor] = None,
+ ) -> Union[tuple, ClvpOutput]:
r"""
- Returns:
+ conditioning_encoder_inputs_embeds (`ms.Tensor`, *optional*):
+ inputs_embeds for `ClvpConditioningEncoder`. Can be used in place of `input_ids`.
+ text_encoder_inputs_embeds (`ms.Tensor`, *optional*):
+ inputs_embeds for the text encoder model passed in place of `input_ids`.
+ return_loss (`bool`, *optional*):
+ Whether or not to return the contrastive loss.
Examples:
```python
>>> import datasets
- >>> from mindone.transformers import ClvpProcessor, ClvpModelForConditionalGeneration
- >>> import mindspore as ms
+ >>> import mindspore as ms
+ >>> from mindone.transformers import ClvpProcessor, ClvpModelForConditionalGeneration
>>> # Define the Text and Load the Audio (We are taking an audio example from HuggingFace Hub using `datasets` library)
>>> text = "This is an example text."
>>> ds = datasets.load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> ds = ds.cast_column("audio", datasets.Audio(sampling_rate=22050))
- >>> _, audio, sr = ds.sort("id").select(range(1))[:1]["audio"][0].values()
+ >>> audio = ds.sort("id")["audio"][0]
+ >>> audio_sample, sr = audio["array"], audio["sampling_rate"]
>>> # Define processor and model
>>> processor = ClvpProcessor.from_pretrained("susnato/clvp_dev")
>>> model = ClvpModelForConditionalGeneration.from_pretrained("susnato/clvp_dev")
>>> # processor outputs and model outputs
- >>> inputs = processor(raw_speech=audio, sampling_rate=sr, text=text, return_tensors="np")
- >>> for k, v in inputs.items():
- ... inputs[k] = ms.Tensor(v)
+ >>> processor_output = processor(raw_speech=audio_sample, sampling_rate=sr, text=text, return_tensors="np")
>>> outputs = model(
- ... input_ids=inputs["input_ids"],
- ... input_features=inputs["input_features"],
+ ...     input_ids=ms.tensor(processor_output["input_ids"]),
+ ...     input_features=ms.tensor(processor_output["input_features"]),
... return_dict=True,
... )
```
@@ -1756,12 +1680,13 @@ def construct(
inputs_embeds=conditioning_embeds,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
+ cache_position=cache_position,
)
speech_ids = decoder_outputs[0]
- # since we will get the embeds of shape `(batch_size, seq_len, embedding_dim)` during the construct pass
- # we must convert it to tokens, to make it compatable with speech_transformer
+ # since we will get the embeds of shape `(batch_size, seq_len, embedding_dim)` during the forward pass
+ # we must convert it to tokens, to make it compatible with speech_transformer
if speech_ids.ndim == 3:
speech_ids = speech_ids.argmax(2)
speech_ids = self.fix_speech_decoder_output(speech_ids)
@@ -1827,10 +1752,11 @@ def construct(
speech_encoder_hidden_states=speech_outputs.hidden_states,
)
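+ # Generation needs no gradients; ms._no_grad() is roughly the MindSpore counterpart of torch.no_grad().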
+ @ms._no_grad()
def generate(
self,
- input_ids: ms.Tensor = None,
- input_features: ms.Tensor = None,
+ input_ids: Optional[ms.Tensor] = None,
+ input_features: Optional[ms.Tensor] = None,
attention_mask: Optional[ms.Tensor] = None,
generation_config: Optional[GenerationConfig] = None,
pad_to_max_mel_tokens: Optional[int] = None,
@@ -1845,8 +1771,6 @@ def generate(
Args:
input_ids (`ms.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Input text Tokens. Processed from the [`ClvpTokenizer`].
- input_features (`ms.Tensor` of shape `(batch_size, feature_size, time_dim)`, *optional*):
- Indicates log-melspectrogram representations for audio returned by [`ClvpFeatureExtractor`].
attention_mask (`ms.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding text token indices. Mask values selected in `[0, 1]`:
diff --git a/mindone/transformers/models/cohere/modeling_cohere.py b/mindone/transformers/models/cohere/modeling_cohere.py
index 2668958469..58be395a06 100644
--- a/mindone/transformers/models/cohere/modeling_cohere.py
+++ b/mindone/transformers/models/cohere/modeling_cohere.py
@@ -1,7 +1,7 @@
from typing import Callable, List, Optional, Tuple, Union
from transformers import CohereConfig
-from transformers.utils import LossKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, logging
+from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging
from transformers.utils.deprecation import deprecate_kwarg
import mindspore as ms
@@ -567,7 +567,7 @@ def _update_causal_mask(
dtype = input_tensor.dtype
sequence_length = input_tensor.shape[1]
if using_static_cache:
- target_length = past_key_values.get_max_length()
+ target_length = past_key_values.get_max_cache_shape()
else:
target_length = (
attention_mask.shape[-1]
@@ -621,10 +621,6 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
return causal_mask
-class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs):
- ...
-
-
class CohereForCausalLM(CoherePreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
diff --git a/mindone/transformers/models/cohere2/modeling_cohere2.py b/mindone/transformers/models/cohere2/modeling_cohere2.py
index c63a6cdd57..8ec591c7a4 100644
--- a/mindone/transformers/models/cohere2/modeling_cohere2.py
+++ b/mindone/transformers/models/cohere2/modeling_cohere2.py
@@ -26,7 +26,6 @@
from transformers import Cohere2Config
from transformers.utils import (
- LossKwargs,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
@@ -49,6 +48,7 @@
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
+from ...utils import TransformersKwargs
logger = logging.get_logger(__name__)
@@ -779,10 +779,6 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
return causal_mask
-class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs):
- ...
-
-
class Cohere2ForCausalLM(Cohere2PreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
@@ -834,7 +830,7 @@ def construct(
return_dict: Optional[bool] = None,
cache_position: Optional[mint.Tensor] = None,
logits_to_keep: Union[int, mint.Tensor] = 0,
- **kwargs: Unpack[KwargsForCausalLM],
+ **kwargs: Unpack[TransformersKwargs],
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
labels (`mint.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
diff --git a/mindone/transformers/models/colpali/modeling_colpali.py b/mindone/transformers/models/colpali/modeling_colpali.py
index e1251f7606..f764f67882 100644
--- a/mindone/transformers/models/colpali/modeling_colpali.py
+++ b/mindone/transformers/models/colpali/modeling_colpali.py
@@ -173,6 +173,13 @@ class ColPaliForRetrievalOutput(ModelOutput):
"""
)
class ColPaliForRetrieval(ColPaliPreTrainedModel):
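+ # Remap key prefixes from checkpoints saved with the older module layout so existing weights still load.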
+ _checkpoint_conversion_mapping = {
+ "vlm.language_model.model": "vlm.model.language_model",
+ "vlm.vision_tower": "vlm.model.vision_tower",
+ "vlm.multi_modal_projector": "vlm.model.multi_modal_projector",
+ "vlm.language_model.lm_head": "vlm.lm_head",
+ }
+
def __init__(self, config: ColPaliConfig):
super().__init__(config)
self.config = config
@@ -272,60 +279,55 @@ def construct(
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- outputs = self.vlm(
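+ # Run the VLM backbone directly (bypassing its LM head); retrieval embeddings are projected from its last hidden states below.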
+ vlm_output = self.vlm.model(
input_ids=input_ids,
attention_mask=attention_mask,
pixel_values=pixel_values,
output_hidden_states=True,
- return_dict=return_dict,
+ return_dict=True,
output_attentions=output_attentions,
**kwargs,
)
+ vlm_hidden_states = vlm_output.hidden_states if output_hidden_states else None
+ vlm_image_hidden_states = vlm_output.image_hidden_states if pixel_values is not None else None
- last_hidden_states = outputs.hidden_states[-1] # (batch_size, sequence_length, hidden_size)
+ last_hidden_states = vlm_output[0] # (batch_size, sequence_length, hidden_size)
embeddings = self.embedding_proj_layer(last_hidden_states) # (batch_size, sequence_length, dim)
# L2 normalization
embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True) # (batch_size, sequence_length, dim)
- embeddings = embeddings * attention_mask.unsqueeze(-1) # (batch_size, sequence_length, dim)
-
- loss = None
- if not return_dict:
- output = (embeddings,) + outputs[2:]
- output[2] = output[2] if output_hidden_states is not None else None
- output[-1] = (outputs.image_hidden_states if pixel_values is not None else None,)
- return (loss,) + output if loss is not None else output
+ if attention_mask is not None:
+ embeddings = embeddings * attention_mask.unsqueeze(-1) # (batch_size, sequence_length, dim)
return ColPaliForRetrievalOutput(
- loss=loss,
embeddings=embeddings,
- past_key_values=outputs.past_key_values,
- hidden_states=outputs.hidden_states if output_hidden_states else None,
- attentions=outputs.attentions,
- image_hidden_states=outputs.image_hidden_states if pixel_values is not None else None,
+ past_key_values=vlm_output.past_key_values,
+ hidden_states=vlm_hidden_states,
+ attentions=vlm_output.attentions,
+ image_hidden_states=vlm_image_hidden_states,
)
def get_input_embeddings(self):
- return self.vlm.language_model.get_input_embeddings()
+ return self.vlm.get_input_embeddings()
def set_input_embeddings(self, value):
- self.vlm.language_model.set_input_embeddings(value)
+ self.vlm.set_input_embeddings(value)
def get_output_embeddings(self):
- return self.vlm.language_model.get_output_embeddings()
+ return self.vlm.get_output_embeddings()
def set_output_embeddings(self, new_embeddings):
- self.vlm.language_model.set_output_embeddings(new_embeddings)
+ self.vlm.set_output_embeddings(new_embeddings)
def set_decoder(self, decoder):
- self.vlm.language_model.set_decoder(decoder)
+ self.vlm.set_decoder(decoder)
def get_decoder(self):
- return self.vlm.language_model.get_decoder()
+ return self.vlm.get_decoder()
def tie_weights(self):
- return self.vlm.language_model.tie_weights()
+ return self.vlm.tie_weights()
def resize_token_embeddings(
self,
@@ -333,7 +335,7 @@ def resize_token_embeddings(
pad_to_multiple_of: Optional[int] = None,
mean_resizing: bool = True,
) -> mint.nn.Embedding:
- model_embeds = self.vlm.language_model.resize_token_embeddings(
+ model_embeds = self.vlm.resize_token_embeddings(
new_num_tokens=new_num_tokens,
pad_to_multiple_of=pad_to_multiple_of,
mean_resizing=mean_resizing,
diff --git a/mindone/transformers/models/emu3/modeling_emu3.py b/mindone/transformers/models/emu3/modeling_emu3.py
index 3f3e580db0..a7faf5e5a8 100644
--- a/mindone/transformers/models/emu3/modeling_emu3.py
+++ b/mindone/transformers/models/emu3/modeling_emu3.py
@@ -28,7 +28,6 @@
from transformers import Emu3Config, Emu3TextConfig, Emu3VQVAEConfig
from transformers.utils import (
- LossKwargs,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
@@ -52,6 +51,7 @@
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
+from ...utils import TransformersKwargs
@jit_class
@@ -1616,10 +1616,6 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
return causal_mask
-class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs):
- ...
-
-
class Emu3ForCausalLM(Emu3PreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
@@ -1670,7 +1666,7 @@ def construct(
return_dict: Optional[bool] = None,
cache_position: Optional[ms.Tensor] = None,
logits_to_keep: Union[int, ms.Tensor] = 0,
- **kwargs: Unpack[KwargsForCausalLM],
+ **kwargs: Unpack[TransformersKwargs],
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
labels (`ms.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1832,13 +1828,12 @@ def construct(
"""
-class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
- _tied_weights_keys = ["text_model.lm_head.weight"]
- _supports_static_cache = False # `get_image_tokens()`, called when `pixel_values` is passed, is not compileable
+class Emu3Model(Emu3PreTrainedModel):
+ _checkpoint_conversion_mapping = {"text_model.model": "text_model"}
def __init__(self, config):
super().__init__(config)
- self.text_model = Emu3ForCausalLM._from_config(config.text_config)
+ self.text_model = Emu3TextModel._from_config(config.text_config)
self.vqmodel = Emu3VQVAE(config.vq_config)
self.vocabulary_mapping = Emu3ImageVocabularyMapping(config.vocabulary_map)
@@ -1887,6 +1882,102 @@ def decode_image_tokens(self, image_tokens: ms.Tensor, height: int, width: int):
image = self.vqmodel.decode(image_tokens)
return image
+ def construct(
+ self,
+ input_ids: ms.Tensor = None,
+ pixel_values: ms.Tensor = None,
+ image_sizes: ms.Tensor = None,
+ attention_mask: Optional[ms.Tensor] = None,
+ position_ids: Optional[ms.Tensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[ms.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ cache_position: Optional[ms.Tensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, CausalLMOutputWithPast]:
+ r"""
+ image_sizes (`ms.Tensor` of shape `(batch_size, 2)`):
+ The sizes of the images in the batch, being (height, width) for each image. Image sizes can be obtained using
+ [`AutoImageProcessor`]. See [`Emu3ImageProcessor.__call__`] for details ([`Emu3Processor`] uses
+ [`Emu3ImageProcessor`] for processing images).
+ """
+ if (input_ids is None) != (inputs_embeds is not None):
+ raise ValueError(
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
+ )
+
+ if pixel_values is not None and inputs_embeds is not None:
+ raise ValueError(
+ "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
+ )
+
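+ # Encode the images into discrete VQ-VAE tokens and scatter them into the positions of the image placeholder ids.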
+ if pixel_values is not None:
+ image_tokens = self.get_image_tokens(pixel_values, image_sizes)
+ special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
+ image_tokens = image_tokens.to(input_ids.dtype)
+ input_ids = input_ids.masked_scatter(special_image_mask, image_tokens)
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.text_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ )
+
+ return outputs
+
+
+class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
+ base_model_prefix = ""
+ _tied_weights_keys = ["lm_head.weight"]
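+ # Remap checkpoints saved with the old single-class layout (text_model.* / vqmodel.*) onto the new Emu3Model + lm_head split.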
+ _checkpoint_conversion_mapping = {
+ "^text_model.model": "model.text_model",
+ "^vqmodel": "model.vqmodel",
+ "^text_model.lm_head": "lm_head",
+ }
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = Emu3Model(config)
+ self.lm_head = mint.nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.model.set_input_embeddings(value)
+
+ def get_output_embeddings(self) -> nn.Cell:
+ return self.lm_head
+
+ def set_decoder(self, decoder):
+ self.model.set_decoder(decoder)
+
+ def get_decoder(self):
+ return self.model.get_decoder()
+
+ # Make modules available through the conditional class for BC
+ @property
+ def text_model(self):
+ return self.model.text_model
+
+ @property
+ def vqmodel(self):
+ return self.model.vqmodel
+
+ @property
+ def vocabulary_mapping(self):
+ return self.model.vocabulary_mapping
+
+ def decode_image_tokens(self, **kwargs):
+ return self.model.decode_image_tokens(**kwargs)
+
@add_start_docstrings_to_model_forward(EMU3_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def construct(
@@ -1905,6 +1996,7 @@ def construct(
cache_position: Optional[ms.Tensor] = None,
labels: Optional[ms.Tensor] = None,
logits_to_keep: Union[int, ms.Tensor] = 0,
+ **kwargs: Unpack[TransformersKwargs],
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
labels (`ms.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1966,44 +2058,35 @@ def construct(
>>> generated_ids = model.generate(**inputs, max_new_tokens=100, do_sample=False)
>>> processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
```"""
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
- output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
- )
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
- if (input_ids is None) != (inputs_embeds is not None):
- raise ValueError(
- "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
- )
-
- if pixel_values is not None and inputs_embeds is not None:
- raise ValueError(
- "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
- )
-
- if pixel_values is not None:
- image_tokens = self.get_image_tokens(pixel_values, image_sizes)
- special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
- image_tokens = image_tokens.to(input_ids.dtype)
- input_ids = input_ids.masked_scatter(special_image_mask, image_tokens)
-
- # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
- outputs = self.text_model(
+ outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
cache_position=cache_position,
- logits_to_keep=logits_to_keep,
+ **kwargs,
)
- return outputs
+ hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(
+ logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
+ )
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
def prepare_inputs_for_generation(
self,
diff --git a/mindone/transformers/models/falcon_mamba/modeling_falcon_mamba.py b/mindone/transformers/models/falcon_mamba/modeling_falcon_mamba.py
index d8cf78d1c3..5146333a3b 100644
--- a/mindone/transformers/models/falcon_mamba/modeling_falcon_mamba.py
+++ b/mindone/transformers/models/falcon_mamba/modeling_falcon_mamba.py
@@ -98,7 +98,7 @@ def __init__(self, config: FalconMambaConfig, layer_idx: int):
self.activation = config.hidden_act
self.act = ACT2FN[config.hidden_act]
- self.use_mambapy = config.use_mambapy
+ self.use_mambapy = config.use_falcon_mambapy
# projection of the input hidden states
self.in_proj = mint.nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=config.use_bias)
diff --git a/mindone/transformers/models/fuyu/modeling_fuyu.py b/mindone/transformers/models/fuyu/modeling_fuyu.py
index 5acb65fd0a..49f1660c75 100644
--- a/mindone/transformers/models/fuyu/modeling_fuyu.py
+++ b/mindone/transformers/models/fuyu/modeling_fuyu.py
@@ -26,10 +26,11 @@
from mindspore import mint
from mindspore.common.initializer import Normal, initializer
+from ...cache_utils import Cache
from ...generation import GenerationMixin
from ...modeling_outputs import CausalLMOutputWithPast
from ...modeling_utils import PreTrainedModel
-from ...models.auto.modeling_auto import AutoModelForCausalLM
+from ...models.auto.modeling_auto import AutoModel
logger = logging.get_logger(__name__)
@@ -125,12 +126,14 @@ def _init_weights(self, module):
"""
-class FuyuForCausalLM(FuyuPreTrainedModel, GenerationMixin):
+class FuyuModel(FuyuPreTrainedModel):
+ _checkpoint_conversion_mapping = {"language_model.model": "language_model"}
+
def __init__(self, config: FuyuConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.text_config.vocab_size
- self.language_model = AutoModelForCausalLM.from_config(config.text_config)
+ self.language_model = AutoModel.from_config(config.text_config)
if self.language_model._tied_weights_keys is not None:
self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]
@@ -202,6 +205,110 @@ def gather_continuous_embeddings(
output_embeddings[batch_idx, dst_indices] = continuous_embeddings[batch_idx][src_indices]
return output_embeddings
+ def construct(
+ self,
+ input_ids: ms.Tensor = None,
+ image_patches: ms.Tensor = None, # [batch_size, num_total_patches, patch_size x patch_size x num_channels]
+ image_patches_indices: ms.Tensor = None,
+ attention_mask: Optional[ms.Tensor] = None,
+ position_ids: Optional[ms.Tensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[ms.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ **kwargs,
+ ) -> Union[tuple, CausalLMOutputWithPast]:
+ r"""
+ image_patches (`ms.Tensor` of shape `(batch_size, num_total_patches, patch_size x patch_size x num_channels)`, *optional*):
+ Image patches to be used as continuous embeddings. The patches are flattened and then projected to the
+ hidden size of the model.
+ image_patches_indices (`ms.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Tensor of indices of the image patches in the input_ids tensor.
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ batch_size, seq_length = input_ids.shape
+ elif inputs_embeds is not None:
+ batch_size, seq_length, _ = inputs_embeds.shape
+ else:
+ raise ValueError("You have to specify either input_is or inputs_embeds")
+
+ seq_length_with_past = seq_length
+ past_key_values_length = 0
+
+ if past_key_values is not None:
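+ # cached key tensors have shape (batch, num_heads, seq_len, head_dim); dim 2 gives the number of cached positions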
+ past_key_values_length = past_key_values[0][0].shape[2]
+ seq_length_with_past = seq_length_with_past + past_key_values_length
+
+ if position_ids is None:
+ position_ids = mint.arange(past_key_values_length, seq_length + past_key_values_length, dtype=ms.int32)
+ position_ids = position_ids.unsqueeze(0)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
+ if image_patches is not None and past_key_values is None:
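+ # image patch embeddings are only spliced into the word embeddings on the prefill pass (no KV cache yet)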
+ patch_embeddings = [
+ self.vision_embed_tokens(patch.to(self.vision_embed_tokens.weight.dtype)).squeeze(0)
+ for patch in image_patches
+ ]
+ inputs_embeds = self.gather_continuous_embeddings(
+ word_embeddings=inputs_embeds,
+ continuous_embeddings=patch_embeddings,
+ image_patch_input_indices=image_patches_indices,
+ )
+
+ outputs = self.language_model(
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ use_cache=use_cache,
+ return_dict=return_dict,
+ **kwargs,
+ )
+
+ return outputs
+
+
+class FuyuForCausalLM(FuyuPreTrainedModel, GenerationMixin):
+ _checkpoint_conversion_mapping = {
+ "^language_model.model": "model.language_model",
+ "^vision_embed_tokens": "model.vision_embed_tokens",
+ "^language_model.lm_head": "lm_head",
+ }
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config: FuyuConfig):
+ super().__init__(config)
+ self.model = FuyuModel(config)
+ self.lm_head = mint.nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.model.set_input_embeddings(value)
+
+ def set_decoder(self, decoder):
+ self.model.set_decoder(decoder)
+
+ def get_decoder(self):
+ return self.model.get_decoder()
+
def construct(
self,
input_ids: ms.Tensor = None,
@@ -216,6 +323,7 @@ def construct(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
+ logits_to_keep: Optional[int] = 0,
**kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
@@ -255,7 +363,6 @@ def construct(
>>> print(generation_text[0])
A blue bus parked on the side of a road.
```"""
-
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -264,53 +371,39 @@ def construct(
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- if input_ids is not None and inputs_embeds is not None:
- raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
- elif input_ids is not None:
- batch_size, seq_length = input_ids.shape
- elif inputs_embeds is not None:
- batch_size, seq_length, _ = inputs_embeds.shape
- else:
- raise ValueError("You have to specify either input_is or inputs_embeds")
-
- seq_length_with_past = seq_length
- past_key_values_length = 0
-
- if past_key_values is not None:
- past_key_values_length = past_key_values[0][0].shape[2]
- seq_length_with_past = seq_length_with_past + past_key_values_length
-
- if position_ids is None:
- position_ids = mint.arange(past_key_values_length, seq_length + past_key_values_length, dtype=ms.int32)
- position_ids = position_ids.unsqueeze(0)
-
- if inputs_embeds is None:
- inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
- if image_patches is not None and past_key_values is None:
- patch_embeddings = [
- self.vision_embed_tokens(patch.to(self.vision_embed_tokens.weight.dtype)).squeeze(0)
- for patch in image_patches
- ]
- inputs_embeds = self.gather_continuous_embeddings(
- word_embeddings=inputs_embeds,
- continuous_embeddings=patch_embeddings,
- image_patch_input_indices=image_patches_indices,
- )
-
- outputs = self.language_model(
+ outputs = self.model(
+ input_ids=input_ids,
+ image_patches=image_patches,
+ image_patches_indices=image_patches_indices,
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
- labels=labels,
use_cache=use_cache,
- return_dict=return_dict,
- **kwargs,
+ return_dict=True,
+ # don't pass kwargs because Persimmon-backbone doesn't accept FA2 kwargs yet, TODO: raushan
)
- return outputs
+ hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(
+ logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
+ )
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
def prepare_inputs_for_generation(
self,
diff --git a/mindone/transformers/models/gemma/modeling_gemma.py b/mindone/transformers/models/gemma/modeling_gemma.py
index 7b9bf90ccb..b6a71a91e4 100644
--- a/mindone/transformers/models/gemma/modeling_gemma.py
+++ b/mindone/transformers/models/gemma/modeling_gemma.py
@@ -25,7 +25,6 @@
from typing import Callable, List, Optional, Tuple, Union
from transformers.models.gemma.configuration_gemma import GemmaConfig
-from transformers.utils import LossKwargs, logging
import mindspore as ms
from mindspore import mint, nn, ops
@@ -45,6 +44,7 @@
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, logging
logger = logging.get_logger(__name__)
@@ -704,10 +704,6 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
return causal_mask
-class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs):
- ...
-
-
class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
@@ -754,7 +750,7 @@ def construct(
return_dict: Optional[bool] = None,
cache_position: Optional[ms.Tensor] = None,
logits_to_keep: Union[int, ms.Tensor] = 0,
- **kwargs: Unpack[KwargsForCausalLM],
+ **kwargs: Unpack[TransformersKwargs],
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
labels (`ms.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
diff --git a/mindone/transformers/models/gemma3/modeling_gemma3.py b/mindone/transformers/models/gemma3/modeling_gemma3.py
index 5add930901..73052312b7 100644
--- a/mindone/transformers/models/gemma3/modeling_gemma3.py
+++ b/mindone/transformers/models/gemma3/modeling_gemma3.py
@@ -44,12 +44,29 @@
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
-from ..auto import AutoModel, AutoModelForCausalLM
+from ..auto import AutoModel
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "Gemma3Config"
+@dataclass
+class Gemma3ModelOutputWithPast(BaseModelOutputWithPast):
+ r"""
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(ms.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ image_hidden_states (`ms.Tensor`, *optional*):
+ A `ms.Tensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
+ image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
+ """
+
+ image_hidden_states: Optional[ms.Tensor] = None
+
+
@dataclass
class Gemma3CausalLMOutputWithPast(ModelOutput):
"""
@@ -1074,17 +1091,18 @@ def construct(self, vision_outputs: ms.Tensor):
return projected_vision_outputs.type_as(vision_outputs)
-class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
+class Gemma3Model(Gemma3PreTrainedModel):
+ _checkpoint_conversion_mapping = {"language_model.model": "language_model"}
+ # we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch
+ accepts_loss_kwargs = False
+
def __init__(self, config: Gemma3Config):
super().__init__(config)
self.vision_tower = AutoModel.from_config(config=config.vision_config)
self.multi_modal_projector = Gemma3MultiModalProjector(config)
self.vocab_size = config.text_config.vocab_size
- language_model = AutoModelForCausalLM.from_config(config=config.text_config)
-
- if language_model._tied_weights_keys is not None:
- self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys]
+ language_model = AutoModel.from_config(config=config.text_config)
self.language_model = language_model
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
@@ -1193,7 +1211,7 @@ def construct(
pixel_values: ms.Tensor = None,
attention_mask: Optional[ms.Tensor] = None,
position_ids: Optional[ms.Tensor] = None,
- past_key_values: Optional[Union[List[ms.Tensor], Cache]] = None,
+ past_key_values: Optional[Union[list[ms.Tensor], Cache]] = None,
token_type_ids: Optional[ms.Tensor] = None,
cache_position: Optional[ms.Tensor] = None,
inputs_embeds: Optional[ms.Tensor] = None,
@@ -1202,53 +1220,35 @@ def construct(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
- logits_to_keep: Union[int, ms.Tensor] = 0,
**lm_kwargs,
- ) -> Union[Tuple, Gemma3CausalLMOutputWithPast]:
+ ) -> Union[tuple, Gemma3ModelOutputWithPast]:
r"""
- labels (`ms.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
- config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
-
- logits_to_keep (`int` or `ms.Tensor`, *optional*):
- If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
- `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
- token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
- If a `ms.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
- This is useful when using packed tensor format (single dimension for batch and sequence length).
-
- Returns:
+ labels (`ms.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
Example:
```python
>>> from PIL import Image
>>> import requests
- >>> from transformers import AutoProcessor
- >>> from mindone.transformers import Gemma3ForConditionalGeneration
- >>> import numpy as np
+ >>> from transformers import AutoProcessor
+ >>> from mindone.transformers import Gemma3ForConditionalGeneration
- >>> model = Gemma3ForConditionalGeneration.from_pretrained("google/Gemma3-test-224px-hf")
- >>> processor = AutoProcessor.from_pretrained("google/Gemma3-test-224px-hf")
+ >>> model = Gemma3ForConditionalGeneration.from_pretrained("google/gemma32-3b-mix-224")
+ >>> processor = AutoProcessor.from_pretrained("google/gemma32-3b-mix-224")
- >>> prompt = "answer en Where is the cow standing?"
- >>> url = "https://huggingface.co/gv-hf/Gemma3-test-224px-hf/resolve/main/cow_beach_1.png"
+ >>> prompt = "Where is the cat standing?"
+ >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
>>> image = Image.open(requests.get(url, stream=True).raw)
- >>> inputs = processor(images=image, text=prompt, return_tensors="np")
- >>> for key, value in inputs.items():
- >>> if isinstance(value, np.ndarray):
- >>> inputs[key] = ms.tensor(value)
- >>> elif isinstance(value, list):
- >>> inputs[key] = ms.tensor(value)
+ >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
>>> # Generate
- >>> generate_ids = model.generate(**inputs, max_length=30)
+ >>> generate_ids = model.generate(**inputs,)
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
- "answer en Where is the cow standing?\nbeach"
+ "Where is the cat standing?\nsnow"
```"""
-
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
@@ -1319,11 +1319,150 @@ def construct(
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
- logits_to_keep=logits_to_keep,
**lm_kwargs,
)
- logits = outputs[0]
+ return Gemma3ModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values if use_cache else None,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=image_features if pixel_values is not None else None,
+ )
+
+
+class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
+ _checkpoint_conversion_mapping = {
+ "^language_model.model": "model.language_model",
+ "^vision_tower": "model.vision_tower",
+ "^multi_modal_projector": "model.multi_modal_projector",
+ "^language_model.lm_head": "lm_head",
+ }
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config: Gemma3Config):
+ super().__init__(config)
+ self.model = Gemma3Model(config)
+ self.lm_head = mint.nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.model.set_input_embeddings(value)
+
+ def set_decoder(self, decoder):
+ self.model.set_decoder(decoder)
+
+ def get_decoder(self):
+ return self.model.get_decoder()
+
+ def get_image_features(self, pixel_values):
+ return self.model.get_image_features(pixel_values)
+
+ # Make modules available through conditional class for BC
+ @property
+ def language_model(self):
+ return self.model.language_model
+
+ @property
+ def vision_tower(self):
+ return self.model.vision_tower
+
+ @property
+ def multi_modal_projector(self):
+ return self.model.multi_modal_projector
+
+ def construct(
+ self,
+ input_ids: ms.Tensor = None,
+ pixel_values: ms.Tensor = None,
+ attention_mask: Optional[ms.Tensor] = None,
+ position_ids: Optional[ms.Tensor] = None,
+ past_key_values: Optional[Union[List[ms.Tensor], Cache]] = None,
+ token_type_ids: Optional[ms.Tensor] = None,
+ cache_position: Optional[ms.Tensor] = None,
+ inputs_embeds: Optional[ms.Tensor] = None,
+ labels: Optional[ms.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ logits_to_keep: Union[int, ms.Tensor] = 0,
+ **lm_kwargs,
+ ) -> Union[Tuple, Gemma3CausalLMOutputWithPast]:
+ r"""
+ labels (`ms.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
+
+ logits_to_keep (`int` or `ms.Tensor`, *optional*):
+ If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
+ `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+ token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+ If a `ms.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+ This is useful when using packed tensor format (single dimension for batch and sequence length).
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor
+ >>> from mindone.transformers import Gemma3ForConditionalGeneration
+ >>> import numpy as np
+
+ >>> model = Gemma3ForConditionalGeneration.from_pretrained("google/Gemma3-test-224px-hf")
+ >>> processor = AutoProcessor.from_pretrained("google/Gemma3-test-224px-hf")
+
+ >>> prompt = "answer en Where is the cow standing?"
+ >>> url = "https://huggingface.co/gv-hf/Gemma3-test-224px-hf/resolve/main/cow_beach_1.png"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> inputs = processor(images=image, text=prompt, return_tensors="np")
+ >>> for key, value in inputs.items():
+ >>> if isinstance(value, np.ndarray):
+ >>> inputs[key] = ms.tensor(value)
+ >>> elif isinstance(value, list):
+ >>> inputs[key] = ms.tensor(value)
+
+ >>> # Generate
+ >>> generate_ids = model.generate(**inputs, max_length=30)
+ >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "answer en Where is the cow standing?\nbeach"
+ ```"""
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.model(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
+ token_type_ids=token_type_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ labels=labels,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ **lm_kwargs,
+ )
+
+ hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
loss = None
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
@@ -1345,6 +1484,7 @@ def construct(
flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
flat_labels = shift_labels.view(-1)
loss = loss_fct(flat_logits, flat_labels)
+
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
@@ -1355,7 +1495,7 @@ def construct(
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
- image_hidden_states=image_features if pixel_values is not None else None,
+ image_hidden_states=outputs.image_hidden_states,
)
def prepare_inputs_for_generation(
diff --git a/mindone/transformers/models/glm/modeling_glm.py b/mindone/transformers/models/glm/modeling_glm.py
index 1bc9200fbb..b2ac469f22 100644
--- a/mindone/transformers/models/glm/modeling_glm.py
+++ b/mindone/transformers/models/glm/modeling_glm.py
@@ -26,7 +26,7 @@
from typing import Callable, Optional, Tuple, Union
from transformers import GlmConfig
-from transformers.utils import LossKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, logging
+from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging
import mindspore as ms
from mindspore import mint, nn, ops
@@ -46,6 +46,7 @@
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
+from ...utils import TransformersKwargs
class GlmRMSNorm(nn.Cell):
@@ -747,10 +748,6 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
return causal_mask
-class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs):
- ...
-
-
class GlmForCausalLM(GlmPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
@@ -800,7 +797,7 @@ def construct(
return_dict: Optional[bool] = False,
cache_position: Optional[ms.Tensor] = None,
logits_to_keep: int = 0,
- **kwargs: Unpack[KwargsForCausalLM],
+ **kwargs: Unpack[TransformersKwargs],
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
diff --git a/mindone/transformers/models/glm4/modeling_glm4.py b/mindone/transformers/models/glm4/modeling_glm4.py
index 455c15357a..e7f5578396 100644
--- a/mindone/transformers/models/glm4/modeling_glm4.py
+++ b/mindone/transformers/models/glm4/modeling_glm4.py
@@ -25,7 +25,6 @@
from typing import Callable, Optional, Tuple, Union
from transformers.models.glm4.configuration_glm4 import Glm4Config
-from transformers.utils import LossKwargs
import mindspore as ms
from mindspore import mint, nn, ops
@@ -44,7 +43,7 @@
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
-from ...utils import logging
+from ...utils import TransformersKwargs, logging
logger = logging.get_logger(__name__)
@@ -289,10 +288,6 @@ def construct(
return attn_output, attn_weights
-class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs):
- ...
-
-
class Glm4RMSNorm(nn.Cell):
def __init__(self, hidden_size, eps=1e-6):
"""
@@ -711,7 +706,7 @@ def construct(
output_hidden_states: Optional[bool] = None,
cache_position: Optional[ms.Tensor] = None,
logits_to_keep: Union[int, ms.Tensor] = 0,
- **kwargs: Unpack[KwargsForCausalLM],
+ **kwargs: Unpack[TransformersKwargs],
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
labels (`ms.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
diff --git a/mindone/transformers/models/glm4v/modeling_glm4v.py b/mindone/transformers/models/glm4v/modeling_glm4v.py
index 90796d60d3..c2568b02c2 100644
--- a/mindone/transformers/models/glm4v/modeling_glm4v.py
+++ b/mindone/transformers/models/glm4v/modeling_glm4v.py
@@ -27,7 +27,7 @@
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
-from transformers.utils import LossKwargs, logging
+from transformers.utils import logging
import mindspore as ms
import mindspore.mint.nn.functional as F
@@ -45,6 +45,7 @@
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
+from ...utils import TransformersKwargs
from .configuration_glm4v import Glm4vConfig, Glm4vTextConfig, Glm4vVisionConfig
logger = logging.get_logger(__name__)
@@ -509,10 +510,6 @@ def construct(self, hidden_states: ms.Tensor, grid_thw: ms.Tensor) -> ms.Tensor:
return hidden_states
-class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs):
- ...
-
-
@dataclass
class Glm4vModelOutputWithPast(ModelOutput):
"""
@@ -1222,7 +1219,7 @@ def construct(
video_grid_thw: Optional[ms.Tensor] = None,
rope_deltas: Optional[ms.Tensor] = None,
cache_position: Optional[ms.Tensor] = None,
- **kwargs: Unpack[KwargsForCausalLM],
+ **kwargs: Unpack[TransformersKwargs],
) -> Union[Tuple, Glm4vModelOutputWithPast]:
r"""
pixel_values_videos (`torch.FloatTensor` of shape `(seq_length, num_channels * temporal_size * image_size * image_size)):
@@ -1463,7 +1460,7 @@ def construct(
video_grid_thw: Optional[ms.Tensor] = None,
rope_deltas: Optional[ms.Tensor] = None,
cache_position: Optional[ms.Tensor] = None,
- **kwargs: Unpack[KwargsForCausalLM],
+ **kwargs: Unpack[TransformersKwargs],
) -> Union[Tuple, Glm4vCausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
diff --git a/mindone/transformers/models/glm4v/processing_glm4v.py b/mindone/transformers/models/glm4v/processing_glm4v.py
index f541e7c601..9c38dea148 100644
--- a/mindone/transformers/models/glm4v/processing_glm4v.py
+++ b/mindone/transformers/models/glm4v/processing_glm4v.py
@@ -212,7 +212,6 @@ def __call__(
text[i] = text[i].replace("<|placeholder|>", self.image_token)
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
- output_kwargs["text_kwargs"].pop("videos")
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}, tensor_type=return_tensors)
diff --git a/mindone/transformers/models/got_ocr2/modeling_got_ocr2.py b/mindone/transformers/models/got_ocr2/modeling_got_ocr2.py
index 60ef5cae34..257680b497 100644
--- a/mindone/transformers/models/got_ocr2/modeling_got_ocr2.py
+++ b/mindone/transformers/models/got_ocr2/modeling_got_ocr2.py
@@ -28,13 +28,16 @@
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
import mindspore as ms
-from mindspore import mint
+from mindspore import mint, nn
from ...activations import ACT2FN
+from ...cache_utils import Cache
from ...generation import GenerationMixin
-from ...modeling_outputs import ModelOutput
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
from ...modeling_utils import PreTrainedModel
-from ..auto import AutoModelForCausalLM
+from ...processing_utils import Unpack
+from ..auto import AutoModel
_CONFIG_FOR_DOC = "GotOcr2Config"
@@ -683,24 +686,32 @@ def _init_weights(self, module):
"""
-@add_start_docstrings(
- """The GOT_OCR2 model which consists of a vision backbone and a language model.""",
- GOT_OCR2_START_DOCSTRING,
-)
-class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin):
+@dataclass
+class GotOcr2ModelOutputWithPast(BaseModelOutputWithPast):
+ r"""
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(ms.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ image_hidden_states (`ms.Tensor`, *optional*):
+ A `ms.Tensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
+ image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
+ """
+
+ image_hidden_states: Optional[ms.Tensor] = None
+
+
+class GotOcr2Model(GotOcr2PreTrainedModel):
+ _checkpoint_conversion_mapping = {"language_model.model": "language_model"}
+
def __init__(self, config: GotOcr2Config):
super().__init__(config)
self.vision_tower = GotOcr2VisionEncoder(config.vision_config)
self.multi_modal_projector = GotOcr2MultiModalProjector(config)
- self.vocab_size = config.text_config.vocab_size
- self.language_model = AutoModelForCausalLM.from_config(config.text_config)
-
- if self.language_model._tied_weights_keys is not None:
- self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]
-
- self.pad_token_id = config.pad_token_id
-
+ self.language_model = AutoModel.from_config(config.text_config)
self.post_init()
def get_input_embeddings(self):
@@ -736,6 +747,139 @@ def get_image_features(
image_outputs = self.vision_tower(pixel_values).last_hidden_state
return self.multi_modal_projector(image_outputs)
+ def construct(
+ self,
+ input_ids: ms.Tensor = None,
+ pixel_values: ms.Tensor = None,
+ attention_mask: Optional[ms.Tensor] = None,
+ position_ids: Optional[ms.Tensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[ms.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[ms.Tensor] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> Union[tuple, GotOcr2ModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if pixel_values is not None and inputs_embeds is not None:
+ raise ValueError(
+ "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
+ )
+
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+
+ if pixel_values is not None:
+ image_features = self.get_image_features(pixel_values=pixel_values.to(inputs_embeds.dtype))
+ n_image_tokens = (input_ids == self.config.image_token_index).sum()
+ n_image_features = image_features.shape[0] * image_features.shape[1]
+ if n_image_tokens != n_image_features:
+ raise ValueError(
+ f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+ )
+ special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
+ special_image_mask = special_image_mask.expand_as(inputs_embeds)
+ image_features = image_features.to(inputs_embeds.dtype)
+ if inputs_embeds.dtype == ms.bfloat16:
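+ # when embeddings are bfloat16, perform masked_scatter in float32 and cast the result back afterwards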
+ inputs_embeds_fp = inputs_embeds.to(ms.float32)
+ image_features_fp = image_features.to(ms.float32)
+ replaced = inputs_embeds_fp.masked_scatter(special_image_mask, image_features_fp)
+ inputs_embeds = replaced.to(inputs_embeds.dtype)
+ else:
+ inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+
+ outputs = self.language_model(
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ )
+
+ return GotOcr2ModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=image_features if pixel_values is not None else None,
+ )
+
+
+@add_start_docstrings(
+ """The GOT_OCR2 model which consists of a vision backbone and a language model.""",
+ GOT_OCR2_START_DOCSTRING,
+)
+class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin):
+ _checkpoint_conversion_mapping = {
+ "^language_model.model": "model.language_model",
+ "^vision_tower": "model.vision_tower",
+ "^multi_modal_projector": "model.multi_modal_projector",
+ "^language_model.lm_head": "lm_head",
+ }
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config: GotOcr2Config):
+ super().__init__(config)
+ self.model = GotOcr2Model(config)
+ self.lm_head = mint.nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.model.set_input_embeddings(value)
+
+ def get_output_embeddings(self) -> nn.Cell:
+ return self.lm_head
+
+ def set_decoder(self, decoder):
+ self.model.set_decoder(decoder)
+
+ def get_decoder(self):
+ return self.model.get_decoder()
+
+ def get_image_features(
+ self,
+ pixel_values: ms.Tensor,
+ vision_feature_layer: Optional[Union[int, list[int]]] = None,
+ vision_feature_select_strategy: Optional[str] = None,
+ **kwargs,
+ ):
+ return self.model.get_image_features(
+ pixel_values=pixel_values,
+ vision_feature_layer=vision_feature_layer,
+ vision_feature_select_strategy=vision_feature_select_strategy,
+ **kwargs,
+ )
+
+ # Make modules available through conditional class for BC
+ @property
+ def language_model(self):
+ return self.model.language_model
+
+ @property
+ def vision_tower(self):
+ return self.model.vision_tower
+
+ @property
+ def multi_modal_projector(self):
+ return self.model.multi_modal_projector
+
@add_start_docstrings_to_model_forward(GOT_OCR2_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=GotOcr2CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def construct(
@@ -753,6 +897,7 @@ def construct(
return_dict: Optional[bool] = None,
cache_position: Optional[ms.Tensor] = None,
logits_to_keep: Union[int, ms.Tensor] = 0,
+ **kwargs,
) -> Union[Tuple, GotOcr2CausalLMOutputWithPast]:
r"""
labels (`mindspore.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -811,44 +956,15 @@ def construct(
"You should keep in mind what features from the module should be used, especially
when you're planning to sell a template."
```"""
-
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- if (input_ids is None) ^ (inputs_embeds is not None):
- raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
-
- if pixel_values is not None and inputs_embeds is not None:
- raise ValueError(
- "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
- )
-
- if inputs_embeds is None:
- inputs_embeds = self.get_input_embeddings()(input_ids)
-
- if pixel_values is not None:
- image_features = self.get_image_features(pixel_values=pixel_values.to(inputs_embeds.dtype))
- n_image_tokens = (input_ids == self.config.image_token_index).sum()
- n_image_features = image_features.shape[0] * image_features.shape[1]
- if n_image_tokens != n_image_features:
- raise ValueError(
- f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
- )
- special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
- special_image_mask = special_image_mask.expand_as(inputs_embeds)
- image_features = image_features.to(inputs_embeds.dtype)
- if inputs_embeds.dtype == ms.bfloat16:
- inputs_embeds_fp = inputs_embeds.to(ms.float32)
- image_features_fp = image_features.to(ms.float32)
- replaced = inputs_embeds_fp.masked_scatter(special_image_mask, image_features_fp)
- inputs_embeds = replaced.to(inputs_embeds.dtype)
- else:
- inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
-
- outputs = self.language_model(
+ outputs = self.model(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
@@ -856,32 +972,22 @@ def construct(
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
- return_dict=return_dict,
+ return_dict=True,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
+ **kwargs,
)
- logits = outputs[0]
+ hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
- # Shift so that tokens < n predict n
- if attention_mask is not None:
- # we use the input attention mask to shift the logits and labels, because it is 2D.
- # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
- shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :]
- shift_logits = logits[..., :-1, :][shift_attention_mask != 0].contiguous()
- shift_labels = labels[..., 1:][shift_attention_mask != 0].contiguous()
- else:
- shift_logits = logits[..., :-1, :].contiguous()
- shift_labels = labels[..., 1:].contiguous()
- # Flatten the tokens
- loss_fct = mint.nn.CrossEntropyLoss()
- loss = loss_fct(shift_logits.view(-1, shift_logits.shape[-1]), shift_labels.view(-1))
-
- if not return_dict:
- output = (logits,) + outputs[1:]
- return (loss,) + output if loss is not None else output
+ loss = self.loss_function(
+ logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
+ )
return GotOcr2CausalLMOutputWithPast(
loss=loss,
@@ -889,7 +995,7 @@ def construct(
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
- image_hidden_states=image_features if pixel_values is not None else None,
+ image_hidden_states=outputs.image_hidden_states,
)
def prepare_inputs_for_generation(
diff --git a/mindone/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/mindone/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
index a3893ef38b..1ee49fea25 100644
--- a/mindone/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
+++ b/mindone/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
@@ -1,5 +1,9 @@
# coding=utf-8
# Copyright 2023 The Bigcode team and HuggingFace Inc. team.
+#
+# This code is adapted from https://github.com/huggingface/transformers
+# with modifications to run transformers on mindspore.
+#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
@@ -11,21 +15,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-"""PyTorch GPTBigCode model."""
+"""MindSpore GPTBigCode model."""
import math
-from typing import List, Optional, Tuple, Union
+from typing import Callable, Optional, Union
from transformers.models.gpt_bigcode.configuration_gpt_bigcode import GPTBigCodeConfig
+from transformers.utils import auto_docstring
import mindspore as ms
-from mindspore import Parameter, mint, nn
-from mindspore.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from mindspore import mint, nn
+from mindspore.mint.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
+from ...cache_utils import Cache, EncoderDecoderCache
from ...generation import GenerationMixin
-from ...mindspore_adapter import dtype_to_min, scaled_dot_product_attention
-from ...modeling_attn_mask_utils import AttentionMaskConverter
+from ...masking_utils import create_causal_mask
from ...modeling_flash_attention_utils import is_flash_attn_available
from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
@@ -33,11 +38,11 @@
SequenceClassifierOutputWithPast,
TokenClassifierOutput,
)
-from ...modeling_utils import PreTrainedModel
-from ...utils import logging
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...utils import can_return_tuple, logging
if is_flash_attn_available():
- from mindspore.ops.operations.nn_ops import FlashAttentionScore as MSFlashAttention
+ pass
logger = logging.get_logger(__name__)
@@ -47,27 +52,70 @@
# Use separate functions for each case because conditionals prevent kernel fusion.
# TODO: Could have better fused kernels depending on scaling, dropout and head mask.
# Is it doable without writing 32 functions?
-def upcast_masked_softmax(x: ms.Tensor, mask: ms.Tensor, mask_value: ms.Tensor, scale: float, softmax_dtype: ms.dtype):
+def upcast_masked_softmax(x: ms.Tensor, mask: ms.Tensor, mask_value: ms.Tensor, scale: float, softmax_dtype: ms.Type):
input_dtype = x.dtype
x = x.to(softmax_dtype) * scale
- x = mint.where(mask.to(ms.bool_), x, mask_value)
- x = mint.functional.softmax(x, dim=-1).to(input_dtype)
+ x = mint.where(mask, x, mask_value)
+ x = mint.nn.functional.softmax(x, dim=-1).to(input_dtype)
return x
-def upcast_softmax(x: ms.Tensor, scale: float, softmax_dtype: ms.dtype):
+def upcast_softmax(x: ms.Tensor, scale: float, softmax_dtype: ms.Type):
input_dtype = x.dtype
x = x.to(softmax_dtype) * scale
- x = mint.functional.softmax(x, dim=-1).to(input_dtype)
+ x = mint.nn.functional.softmax(x, dim=-1).to(input_dtype)
return x
def masked_softmax(x: ms.Tensor, mask: ms.Tensor, mask_value: ms.Tensor):
- x = mint.where(mask.to(ms.bool_), x, mask_value)
- x = mint.functional.softmax(x, dim=-1)
+ x = mint.where(mask, x, mask_value)
+ x = mint.nn.functional.softmax(x, dim=-1)
return x
+def repeat_kv(hidden_states: ms.Tensor, n_rep: int) -> ms.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand((batch, num_key_value_heads, n_rep, slen, head_dim))
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+def eager_attention_forward(
+ module: nn.Cell,
+ query: ms.Tensor,
+ key: ms.Tensor,
+ value: ms.Tensor,
+ attention_mask: Optional[ms.Tensor],
+ scaling: float,
+ dropout: float = 0.0,
+ head_mask: Optional[ms.Tensor] = None,
+ **kwargs,
+):
+ key_states = repeat_kv(key, module.num_key_value_groups)
+ value_states = repeat_kv(value, module.num_key_value_groups)
+
+ attn_weights = mint.matmul(query, key_states.transpose(2, 3)) * scaling
+ if attention_mask is not None:
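+ # the attention mask is additive; crop it to the current key length before adding it to the scores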
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ attn_weights = mint.nn.functional.softmax(attn_weights, dim=-1, dtype=ms.float32).to(query.dtype)
+ attn_weights = mint.nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+
+ if head_mask is not None:
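+ # head_mask holds one entry per attention head; broadcast it over the batch, query and key dimensions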
+ attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
+
+ attn_output = mint.matmul(attn_weights, value_states)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
class GPTBigCodeAttention(nn.Cell):
def __init__(self, config, is_cross_attention=False, layer_idx=None):
super().__init__()
@@ -80,6 +128,7 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None):
self.head_dim = self.embed_dim // self.num_heads
self.kv_heads = 1 if self.multi_query else self.num_heads
self.kv_dim = self.kv_heads * self.head_dim
+ self.num_key_value_groups = self.num_heads // self.kv_heads
self.split_size = self.embed_dim
self.is_causal = True
@@ -90,6 +139,7 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None):
)
self.scale_attn_weights = config.scale_attn_weights
+ self.scaling = self.head_dim**-0.5 if config.scale_attn_weights else 1.0
self.is_cross_attention = is_cross_attention
self.layer_idx = layer_idx
@@ -110,366 +160,90 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None):
self.c_proj = mint.nn.Linear(self.embed_dim, self.embed_dim)
- self.attn_dropout = mint.nn.Dropout(config.attn_pdrop)
- self.resid_dropout = mint.nn.Dropout(config.resid_pdrop)
-
- def _get_mask_value(self, dtype):
- # mint.where expects a tensor. We use a cache to avoid recreating it every time.
- if self.mask_value is None or self.mask_value.dtype != dtype:
- self.mask_value = mint.full([], dtype_to_min(dtype).item(), dtype=dtype)
- return self.mask_value
-
- def _attn(self, query, key, value, attention_mask=None, head_mask=None):
- dtype = query.dtype
- softmax_dtype = ms.float32 if self.attention_softmax_in_fp32 else dtype
- upcast = dtype != softmax_dtype
-
- unscale = self.layer_idx + 1 if self.scale_attention_softmax_in_fp32 and upcast else 1
- scale_factor = unscale**-1
- if self.scale_attn_weights:
- scale_factor /= self.head_dim**0.5
-
- # MQA models: (batch_size, query_length, num_heads * head_dim)
- # MHA models: (batch_size, num_heads, query_length, head_dim)
- query_shape = query.shape
- batch_size = query_shape[0]
- key_length = key.shape[-1]
- if self.multi_query:
- # (batch_size, query_length, num_heads, head_dim) x (batch_size, head_dim, key_length)
- # -> (batch_size, query_length, num_heads, key_length)
- query_length = query_shape[1]
- attn_shape = (batch_size, query_length, self.num_heads, key_length)
- attn_view = (batch_size, query_length * self.num_heads, key_length)
- # No copy needed for MQA 2, or when layer_past is provided.
- query = query.reshape(batch_size, query_length * self.num_heads, self.head_dim)
- else:
- # (batch_size, num_heads, query_length, head_dim) x (batch_size, num_heads, head_dim, key_length)
- # -> (batch_size, num_heads, query_length, key_length)
- query_length = query_shape[2]
- attn_shape = (batch_size, self.num_heads, query_length, key_length)
- attn_view = (batch_size * self.num_heads, query_length, key_length)
- # Always copies
- query = query.reshape(batch_size * self.num_heads, query_length, self.head_dim)
- # No copy when layer_past is provided.
- key = key.reshape(batch_size * self.num_heads, self.head_dim, key_length)
-
- attn_weights = mint.empty(attn_view, dtype=query.dtype)
- beta = 0
- attn_weights = mint.baddbmm(attn_weights, query, key, beta=beta, alpha=scale_factor).view(attn_shape)
-
- if upcast:
- # Use a fused kernel to prevent a large overhead from casting and scaling.
- # Sub-optimal when the key length is not a multiple of 8.
- if attention_mask is None:
- attn_weights = upcast_softmax(attn_weights, unscale, softmax_dtype)
- else:
- mask_value = self._get_mask_value(softmax_dtype)
- attn_weights = upcast_masked_softmax(attn_weights, attention_mask, mask_value, unscale, softmax_dtype)
- else:
- if attention_mask is not None:
- mask_value = self._get_mask_value(softmax_dtype)
-
- # The fused kernel is very slow when the key length is not a multiple of 8, so we skip fusion.
- attn_weights = mint.where(attention_mask.to(ms.bool_), attn_weights, mask_value)
-
- attn_weights = mint.functional.softmax(attn_weights, dim=-1)
-
- attn_weights = self.attn_dropout(attn_weights)
-
- # Mask heads if we want to
- if head_mask is not None:
- if self.multi_query:
- head_mask = mint.transpose(head_mask, 1, 2)
- attn_weights = attn_weights * head_mask
-
- if self.multi_query:
- attn_output = mint.bmm(attn_weights.view(attn_view), value).view(query_shape)
- else:
- attn_output = mint.matmul(attn_weights, value)
-
- return attn_output, attn_weights
+ self.attn_dropout = config.attn_pdrop
+ self.resid_dropout = nn.Dropout(config.resid_pdrop)
def construct(
self,
hidden_states: ms.Tensor,
- layer_past: Optional[ms.Tensor] = None,
+ layer_past: Optional[Cache] = None,
attention_mask: Optional[ms.Tensor] = None,
head_mask: Optional[ms.Tensor] = None,
encoder_hidden_states: Optional[ms.Tensor] = None,
encoder_attention_mask: Optional[ms.Tensor] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
- ) -> Union[Tuple[ms.Tensor, Optional[ms.Tensor]], Tuple[ms.Tensor, Optional[ms.Tensor], Tuple[ms.Tensor, ...]]]:
- if encoder_hidden_states is not None:
- if not hasattr(self, "q_attn") or not self.is_cross_attention:
- raise ValueError(
- "If class is used as cross attention, the weights `q_attn` have to be defined. "
- "Please make sure to instantiate class with `GPTBigCodeAttention(..., is_cross_attention=True)`."
- )
-
- query = self.q_attn(hidden_states)
- key_value = self.c_attn(encoder_hidden_states)
- attention_mask = encoder_attention_mask
- elif self.multi_query:
- query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=2)
- else:
- # Note: We split as (self.num_heads, 3, self.head_dim) instead of (3, self.num_heads, self.head_dim),
- # i.e., the memory layout is not the same as GPT2.
- # This makes the concatenation with past_key_value more efficient.
- query, key_value = mint.transpose(
- self.c_attn(hidden_states).view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim), 1, 2
- ).split((self.head_dim, 2 * self.head_dim), dim=3)
+ cache_position: Optional[ms.Tensor] = None,
+ **kwargs,
+ ) -> Union[tuple[ms.Tensor, Optional[ms.Tensor]], tuple[ms.Tensor, Optional[ms.Tensor], tuple[ms.Tensor, ...]]]:
+ input_shape = hidden_states.shape[:-1]
if layer_past is not None:
- key_value = mint.cat((layer_past, key_value), dim=-2)
- present = key_value if use_cache else None
-
- key, value = key_value.split((self.head_dim, self.head_dim), dim=-1)
-
- attn_output, attn_weights = self._attn(query, mint.transpose(key, -1, -2), value, attention_mask, head_mask)
-
- if not self.multi_query:
- attn_output = mint.transpose(attn_output, 1, 2).reshape(hidden_states.shape)
- attn_output = self.c_proj(attn_output)
- attn_output = self.resid_dropout(attn_output)
-
- outputs = (attn_output, present)
- if output_attentions:
- if self.multi_query:
- # Transpose to return weights in the usual format (batch_size, num_heads, query_length, key_length)
- attn_weights = mint.transpose(attn_weights, 1, 2)
- outputs += (attn_weights,)
-
- return outputs # a, present, (attentions)
-
-
-class GPTBigCodeFlashAttention2(GPTBigCodeAttention):
- """
- GPTBigCode flash attention module. This module inherits from `GPTBigCodeAttention` as the weights of the module
- stays untouched. The only required change would be on the forward pass where it needs to correctly call the public
- API of flash attention and deal with padding tokens in case the input contains any of them.
- """
-
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
-
- attn_dropout = self.attn_pdrop if self.training else 0.0
- self.flash_attention = MSFlashAttention(
- head_num=self.num_heads,
- keep_prob=1 - attn_dropout,
- input_layout="BSND",
- )
+ if isinstance(layer_past, EncoderDecoderCache):
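+ # an EncoderDecoderCache bundles a self-attention cache and a cross-attention cache; select the relevant one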
+ is_updated = layer_past.is_updated.get(self.layer_idx)
+ if self.is_cross_attention:
+ # after the first generated id, we can subsequently re-use all key/value_states from cache
+ curr_past_key_value = layer_past.cross_attention_cache
+ else:
+ curr_past_key_value = layer_past.self_attention_cache
+ else:
+ curr_past_key_value = layer_past
- def construct(
- self,
- hidden_states: ms.Tensor,
- layer_past: Optional[ms.Tensor] = None,
- attention_mask: Optional[ms.Tensor] = None,
- head_mask: Optional[ms.Tensor] = None,
- encoder_hidden_states: Optional[ms.Tensor] = None,
- encoder_attention_mask: Optional[ms.Tensor] = None,
- use_cache: Optional[bool] = False,
- output_attentions: Optional[bool] = False,
- ) -> Union[Tuple[ms.Tensor, Optional[ms.Tensor]], Tuple[ms.Tensor, Optional[ms.Tensor], Tuple[ms.Tensor, ...]]]:
- if encoder_hidden_states is not None:
+ if self.is_cross_attention:
if not hasattr(self, "q_attn") or not self.is_cross_attention:
raise ValueError(
"If class is used as cross attention, the weights `q_attn` have to be defined. "
"Please make sure to instantiate class with `GPTBigCodeAttention(..., is_cross_attention=True)`."
)
-
- query = self.q_attn(hidden_states)
- key_value = self.c_attn(encoder_hidden_states)
- elif self.multi_query:
- query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=2)
- else:
- # Note: We split as (self.num_heads, 3, self.head_dim) instead of (3, self.num_heads, self.head_dim),
- # i.e., the memory layout is not the same as GPT2.
- # This makes the concatenation with past_key_value more efficient.
- query, key_value = mint.transpose(
- self.c_attn(hidden_states).view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim), 1, 2
- ).split((self.head_dim, 2 * self.head_dim), dim=3)
-
- if layer_past is not None:
- key_value = mint.cat((layer_past, key_value), dim=-2)
- present = key_value if use_cache else None
-
- key, value = key_value.split((self.head_dim, self.head_dim), dim=-1)
-
- # Flash attention requires the input to have the shape
- # batch_size x seq_length x head_dim x hidden_dim
- if self.multi_query:
- batch_size, query_length, _ = query.shape
- query = query.reshape(batch_size, query_length, self.num_heads, self.head_dim)
- key = key.unsqueeze(2)
- value = value.unsqueeze(2)
- else:
- query_length = query.shape[2]
- batch_size, _, tgt, _ = key.shape
- query = mint.transpose(query, 1, 2).reshape(batch_size, query_length, self.num_heads, self.head_dim)
- key = mint.transpose(key, 1, 2).reshape(batch_size, tgt, self.num_heads, self.head_dim)
- value = mint.transpose(value, 1, 2).reshape(batch_size, tgt, self.num_heads, self.head_dim)
-
- # In PEFT, usually we cast the layer norms in float32 for training stability reasons
- # therefore the input hidden states gets silently casted in float32. Hence, we need
- # cast them back in float16 just to be sure everything works as expected.
- input_dtype = query.dtype
- if input_dtype == ms.float32:
- # Handle the case where the model is quantized
- if hasattr(self.config, "_pre_quantization_dtype"):
- target_dtype = self.config._pre_quantization_dtype
+ if layer_past is not None and is_updated:
+ # reuse k,v, cross_attentions
+ key = curr_past_key_value.key_cache[self.layer_idx]
+ value = curr_past_key_value.value_cache[self.layer_idx]
else:
- target_dtype = self.c_attn.weight.dtype
-
- logger.warning(
- f"The input hidden states seems to be silently casted in float32, this might be related to"
- f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
- f" {target_dtype}."
- )
- query = query.to(target_dtype)
- key = key.to(target_dtype)
- value = value.to(target_dtype)
-
- attn_output = self.flash_attention(query, key, value)[3]
-
- attn_weights_reshaped = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim)
- attn_output = self.c_proj(attn_weights_reshaped)
- attn_output = self.resid_dropout(attn_output)
-
- outputs = (attn_output, present)
-
- if output_attentions:
- if self.multi_query:
- # Transpose to return weights in the usual format (batch_size, num_heads, query_length, key_length)
- attn_weights_reshaped = mint.transpose(attn_weights_reshaped, 1, 2)
- else:
- attn_weights_reshaped = None
-
- outputs += (attn_weights_reshaped,)
-
- return outputs # a, present, (attentions)
-
-
-class GPTBigCodeSdpaAttention(GPTBigCodeAttention):
- def _attn(self, query, key, value, attention_mask=None, head_mask=None):
- # MQA models: (batch_size, query_length, num_heads * head_dim)
- # MHA models: (batch_size, num_heads, query_length, head_dim)
- query_shape = query.shape
- batch_size = query_shape[0]
- key.shape[-2]
-
- if self.multi_query:
- query_length = query_shape[1]
-
- # SDPA requires the dimension [..., sequence_length, head_dim].
- query = mint.transpose(query.view(batch_size, query_length, self.num_heads, self.head_dim), 1, 2)
-
- # Without these unsqueeze, SDPA complains as the query and key/value have a different number of dimensions.
- key = key.unsqueeze(1)
- value = value.unsqueeze(1)
-
- # Although these expand are not numerically useful, PyTorch can not dispatch to memory-efficient backend
- # and flash attention backend (No available kernel. Aborting execution.) from the shapes
- # query = [batch_size, num_heads, query_length, head_dim]
- # key = [batch_size, 1, past_length, head_dim]
- # value = [batch_size, 1, past_length, head_dim]
- key = mint.broadcast_to(key, (-1, self.num_heads, -1, -1))
- value = mint.broadcast_to(value, (-1, self.num_heads, -1, -1))
+ query = self.q_attn(hidden_states).view(*input_shape, -1, self.head_dim).transpose(1, 2)
+ key, value = self.c_attn(encoder_hidden_states).split((self.head_dim, self.head_dim), dim=-1)
else:
- query_length = query_shape[-1]
-
- # See the comment above.
- if attention_mask is not None:
- query = query.contiguous()
- key = key.contiguous()
- value = value.contiguous()
-
- # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an
- # inline conditional assignment
- # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional
- # prevents dynamic shapes from compiling.
- # The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not
- # create a causal mask in case query_length == 1.
-
- sdpa_result = scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
-
- if self.multi_query:
- # (batch_size, num_heads, seq_len, head_dim) --> (batch_size, seq_len, num_heads, head_dim)
- sdpa_result = mint.transpose(sdpa_result, 1, 2)
-
- # Reshape is kind of expensive here, as it does a memory copy,
- # but I did not manage to make away without it (logits do not match when using view)
- # (batch_size, seq_len, num_heads, head_dim) --> (batch_size, seq_len, num_heads * head_dim)
- sdpa_result = sdpa_result.reshape(query_shape)
-
- return sdpa_result, None
-
- def construct(
- self,
- hidden_states: ms.Tensor,
- layer_past: Optional[ms.Tensor] = None,
- attention_mask: Optional[ms.Tensor] = None,
- head_mask: Optional[ms.Tensor] = None,
- encoder_hidden_states: Optional[ms.Tensor] = None,
- encoder_attention_mask: Optional[ms.Tensor] = None,
- use_cache: Optional[bool] = False,
- output_attentions: Optional[bool] = False,
- ) -> Union[Tuple[ms.Tensor, Optional[ms.Tensor]], Tuple[ms.Tensor, Optional[ms.Tensor], Tuple[ms.Tensor, ...]]]:
- if encoder_hidden_states is not None:
- if not hasattr(self, "q_attn") or not self.is_cross_attention:
- raise ValueError(
- "If class is used as cross attention, the weights `q_attn` have to be defined. "
- "Please make sure to instantiate class with `GPTBigCodeAttention(..., is_cross_attention=True)`."
+ if self.multi_query:
+ query, key, value = (
+ self.c_attn(hidden_states).unsqueeze(1).split((self.embed_dim, self.kv_dim, self.kv_dim), dim=3)
+ )
+ query = query.view(*input_shape, -1, self.head_dim).transpose(1, 2)
+ else:
+ query, key, value = (
+ self.c_attn(hidden_states)
+ .view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim)
+ .transpose(1, 2)
+ .split(3 * [self.head_dim], dim=3)
)
-
- query = self.q_attn(hidden_states)
- key_value = self.c_attn(encoder_hidden_states)
- attention_mask = encoder_attention_mask
- elif self.multi_query:
- query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=2)
- else:
- # Note: We split as (self.num_heads, 3, self.head_dim) instead of (3, self.num_heads, self.head_dim),
- # i.e., the memory layout is not the same as GPT2.
- # This makes the concatenation with past_key_value more efficient.
- query, key_value = mint.transpose(
- self.c_attn(hidden_states).view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim), 1, 2
- ).split((self.head_dim, 2 * self.head_dim), dim=3)
if layer_past is not None:
- key_value = mint.cat((layer_past, key_value), dim=-2)
- present = key_value if use_cache else None
-
- key, value = key_value.split((self.head_dim, self.head_dim), dim=-1)
-
- if not output_attentions and head_mask is None:
- # Difference with the original implementation: there is no need to transpose the key here,
- # as SDPA expects seq_length to be at index -2 for the key as well
- attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
- else:
- # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"`
- # once this is implemented.
- logger.warning(
- "GPTBigCodeModel is using GPTBigCodeSdpaAttention, but `torch.nn.functional."
- "scaled_dot_product_attention` does not support `output_attentions=True` and `head_mask` not None."
- " Falling back to the manual attention implementation, but specifying the manual implementation will "
- "be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument "
- '`attn_implementation="eager"` when loading the model.'
- )
- attn_output, attn_weights = super()._attn(query, key.transpose(-1, -2), value, attention_mask, head_mask)
+ # save all key/value_states to cache to be re-used for fast auto-regressive generation
+ cache_position = cache_position if not self.is_cross_attention else None
+ key, value = curr_past_key_value.update(key, value, self.layer_idx, {"cache_position": cache_position})
+ # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
+ if self.is_cross_attention:
+ layer_past.is_updated[self.layer_idx] = True
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query,
+ key,
+ value,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attn_dropout,
+ scaling=self.scaling,
+ head_mask=head_mask,
+ **kwargs,
+ )
- if not self.multi_query:
- attn_output = mint.transpose(attn_output, 1, 2).reshape(hidden_states.shape)
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.c_proj(attn_output)
attn_output = self.resid_dropout(attn_output)
-
- outputs = (attn_output, present)
- if output_attentions:
- if self.multi_query:
- # Transpose to return weights in the usual format (batch_size, num_heads, query_length, key_length)
- attn_weights = mint.transpose(attn_weights, 1, 2)
- outputs += (attn_weights,)
-
- return outputs
+ return attn_output, attn_weights
class GPTBigCodeMLP(nn.Cell):
@@ -479,10 +253,10 @@ def __init__(self, intermediate_size, config):
self.c_fc = mint.nn.Linear(embed_dim, intermediate_size)
self.c_proj = mint.nn.Linear(intermediate_size, embed_dim)
self.act = ACT2FN[config.activation_function]
- self.dropout = mint.nn.Dropout(config.resid_pdrop)
+ self.dropout = nn.Dropout(config.resid_pdrop)
- # Copied from transformers.models.gpt2.modeling_gpt2.GPT2MLP.construct
- def construct(self, hidden_states: Optional[Tuple[ms.Tensor]]) -> ms.Tensor:
+ # Copied from transformers.models.gpt2.modeling_gpt2.GPT2MLP.forward
+ def construct(self, hidden_states: Optional[tuple[ms.Tensor]]) -> ms.Tensor:
hidden_states = self.c_fc(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.c_proj(hidden_states)
@@ -490,13 +264,6 @@ def construct(self, hidden_states: Optional[Tuple[ms.Tensor]]) -> ms.Tensor:
return hidden_states
-GPTBIGCODE_ATTENTION_CLASSES = {
- "eager": GPTBigCodeAttention,
- "flash_attention_2": GPTBigCodeFlashAttention2,
- "sdpa": GPTBigCodeSdpaAttention,
-}
-
-
class GPTBigCodeBlock(nn.Cell):
def __init__(self, config, layer_idx=None):
super().__init__()
@@ -505,7 +272,7 @@ def __init__(self, config, layer_idx=None):
self.ln_1 = mint.nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
- self.attn = GPTBIGCODE_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx)
+ self.attn = GPTBigCodeAttention(config, layer_idx=layer_idx)
self.ln_2 = mint.nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
@@ -513,9 +280,7 @@ def __init__(self, config, layer_idx=None):
if config.multi_query:
raise NotImplementedError("Cross-attention not implemented for MQA")
- self.crossattention = GPTBIGCODE_ATTENTION_CLASSES[config._attn_implementation](
- config, is_cross_attention=True, layer_idx=layer_idx
- )
+ self.crossattention = GPTBigCodeAttention(config, is_cross_attention=True, layer_idx=layer_idx)
self.ln_cross_attn = mint.nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
@@ -523,16 +288,17 @@ def __init__(self, config, layer_idx=None):
def construct(
self,
- hidden_states: Optional[Tuple[ms.Tensor]],
- layer_past: Optional[ms.Tensor] = None,
+ hidden_states: Optional[tuple[ms.Tensor]],
+ layer_past: Optional[Cache] = None,
attention_mask: Optional[ms.Tensor] = None,
head_mask: Optional[ms.Tensor] = None,
encoder_hidden_states: Optional[ms.Tensor] = None,
encoder_attention_mask: Optional[ms.Tensor] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
+ cache_position: Optional[ms.Tensor] = None,
**kwargs,
- ) -> Union[Tuple[ms.Tensor], Tuple[ms.Tensor, ms.Tensor], Tuple[ms.Tensor, ms.Tensor, ms.Tensor]]:
+ ) -> Union[tuple[ms.Tensor], tuple[ms.Tensor, ms.Tensor], tuple[ms.Tensor, ms.Tensor, ms.Tensor]]:
residual = hidden_states
hidden_states = self.ln_1(hidden_states)
attn_outputs = self.attn(
@@ -542,6 +308,8 @@ def construct(
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
+ cache_position=cache_position,
+ **kwargs,
)
attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
outputs = attn_outputs[1:]
@@ -564,47 +332,40 @@ def construct(
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
+ cache_position=cache_position,
+ **kwargs,
)
attn_output = cross_attn_outputs[0]
# residual connection
hidden_states = residual + attn_output
- outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights
+ outputs = outputs + cross_attn_outputs[1:] # add cross attentions if we output attention weights
residual = hidden_states
hidden_states = self.ln_2(hidden_states)
feed_forward_hidden_states = self.mlp(hidden_states)
- # residual connection
hidden_states = residual + feed_forward_hidden_states
-
- if use_cache:
- outputs = (hidden_states,) + outputs
- else:
- outputs = (hidden_states,) + outputs[1:]
-
- return outputs # hidden_states, present, (attentions, cross_attentions)
+ return (hidden_states,) + outputs
+@auto_docstring
class GPTBigCodePreTrainedModel(PreTrainedModel):
- config_class = GPTBigCodeConfig
+ config: GPTBigCodeConfig
base_model_prefix = "transformer"
supports_gradient_checkpointing = True
_no_split_modules = ["GPTBigCodeBlock"]
_skip_keys_device_placement = "past_key_values"
- _supports_flash_attn_2 = True
+ _supports_flash_attn = True
_supports_sdpa = True
def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
- self._supports_dynamic_input = True
def _init_weights(self, module):
"""Initialize the weights."""
if isinstance(module, (GPTBigCodeMLP, GPTBigCodeAttention)):
# Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
- # > A modified initialization which accounts for the accumulation on the residual path with model depth.
- # Scale
- # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of
- # residual layers.
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
# > -- GPT-2 :: https://openai.com/blog/better-language-models/
#
# Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
@@ -627,6 +388,7 @@ def _init_weights(self, module):
module.weight.data.fill_(1.0)
+@auto_docstring
class GPTBigCodeModel(GPTBigCodePreTrainedModel):
def __init__(self, config):
super().__init__(config)
@@ -636,12 +398,14 @@ def __init__(self, config):
self.wte = mint.nn.Embedding(config.vocab_size, self.embed_dim)
self.wpe = mint.nn.Embedding(config.max_position_embeddings, self.embed_dim)
- self.drop = mint.nn.Dropout(config.embd_pdrop)
+ self.drop = nn.Dropout(config.embd_pdrop)
self.h = nn.CellList([GPTBigCodeBlock(config, layer_idx=i) for i in range(config.num_hidden_layers)])
self.ln_f = mint.nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
max_positions = config.max_position_embeddings
- self.register_buffer("bias", mint.tril(mint.ones((max_positions, max_positions), dtype=ms.int32)))
+ self.register_buffer(
+ "bias", mint.tril(mint.ones((max_positions, max_positions), dtype=ms.bool_)), persistent=False
+ )
self.gradient_checkpointing = False
@@ -657,13 +421,12 @@ def get_input_embeddings(self):
def set_input_embeddings(self, new_embeddings):
self.wte = new_embeddings
- def register_buffer(self, name, attr):
- setattr(self, name, Parameter(default_input=attr, requires_grad=False))
-
+ @can_return_tuple
+ @auto_docstring
def construct(
self,
input_ids: Optional[ms.Tensor] = None,
- past_key_values: Optional[List[ms.Tensor]] = None,
+ past_key_values: Optional[list[ms.Tensor]] = None,
attention_mask: Optional[ms.Tensor] = None,
token_type_ids: Optional[ms.Tensor] = None,
position_ids: Optional[ms.Tensor] = None,
@@ -675,11 +438,13 @@ def construct(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
- ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
+ cache_position: Optional[ms.Tensor] = None,
+ **kwargs,
+ ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]:
r"""
input_ids (`ms.Tensor` of shape `(batch_size, input_ids_length)`):
`input_ids_length` = `sequence_length` if `past_key_values` is `None` else
- `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input
+ `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input
sequence tokens in the vocabulary.
If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
@@ -697,10 +462,9 @@ def construct(
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- if input_ids is not None and inputs_embeds is not None:
- raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
elif input_ids is not None:
- self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
input_shape = input_ids.shape
input_ids = input_ids.view(-1, input_shape[-1])
batch_size = input_ids.shape[0]
@@ -713,76 +477,42 @@ def construct(
if batch_size <= 0:
raise ValueError("batch_size has to be defined and > 0")
- if token_type_ids is not None:
- token_type_ids = token_type_ids.view(-1, input_shape[-1])
+ return_legacy_cache = False
+ if use_cache and not isinstance(past_key_values, Cache):
+ logger.warning_once(
+ "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
+ "You should pass an instance of `EncoderDecoderCache` instead, e.g. "
+ "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
+ )
+ return_legacy_cache = True
+ past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
- if past_key_values is None:
- past_length = 0
- past_key_values = tuple([None] * len(self.h))
- else:
- past_length = past_key_values[0].shape[-2]
-
- if attention_mask is not None and len(attention_mask.shape) == 2 and position_ids is None:
- # create position_ids on the fly for batch generation
- position_ids = attention_mask.to(ms.int32).cumsum(-1) - 1
- position_ids.masked_fill_(attention_mask == 0, 1)
- if past_length > 0:
- position_ids = position_ids[:, past_length : input_shape[-1] + past_length :]
- elif position_ids is None:
- position_ids = mint.arange(past_length, input_shape[-1] + past_length, dtype=ms.int64)
- position_ids = position_ids.unsqueeze(0)
-
- # Self-attention mask.
- query_length = input_shape[-1]
- key_length = past_length + query_length
- self_attention_mask = self.bias[None, key_length - query_length : key_length, :key_length]
+ if inputs_embeds is None:
+ inputs_embeds = self.wte(input_ids)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = mint.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1])
+
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ causal_mask = create_causal_mask(
+ config=self.config,
+ input_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ )
if self._use_flash_attention_2:
- # 2d mask is passed through the layers
- attention_mask = attention_mask.bool() if (attention_mask is not None and 0 in attention_mask) else None
encoder_attention_mask = (
encoder_attention_mask.bool()
if (encoder_attention_mask is not None and 0 in encoder_attention_mask)
else None
)
else:
- # 4d mask is passed through the layers
- if attention_mask is not None:
- self_attention_mask = self_attention_mask * attention_mask.view(batch_size, 1, -1).bool()
-
- # MQA models: (batch_size, query_length, n_heads, key_length)
- # MHA models: (batch_size, n_heads, query_length, key_length)
- self_attention_mask = self_attention_mask.unsqueeze(2 if self.multi_query else 1)
-
- if self._use_sdpa and head_mask is None and not output_attentions:
- # SDPA with a custom mask is much faster in fp16/fp32 dtype rather than bool. Cast here to floating
- # point instead of at every layer.
- dtype = self.wte.weight.dtype
- min_dtype = dtype_to_min(dtype)
- self_attention_mask = mint.where(
- self_attention_mask.to(ms.bool_),
- mint.full([], 0.0, dtype=dtype),
- mint.full([], fill_value=min_dtype.item(), dtype=dtype),
- )
-
- # output_attentions=True can not be supported when using SDPA, and we fall back on
- # the manual implementation that requires a 4D causal mask in all cases.
- if self.multi_query:
- # gpt_bigcode using MQA has the bad taste to use a causal mask with shape
- # [batch_size, target_length, 1, source_length], not compatible with SDPA, hence this transpose.
- self_attention_mask = mint.transpose(self_attention_mask, 1, 2)
-
- if query_length > 1 and attention_mask is not None:
- # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention
- # backend
- # produces nans if sequences are completely unattended in the attention mask.
- # Details: https://github.com/pytorch/pytorch/issues/110213
- self_attention_mask = AttentionMaskConverter._unmask_unattended(
- self_attention_mask, min_dtype=min_dtype
- )
-
- attention_mask = self_attention_mask
-
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
if (
@@ -803,59 +533,42 @@ def construct(
# head_mask has shape n_layer x batch x n_heads x N x N
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
- if inputs_embeds is None:
- inputs_embeds = self.wte(input_ids)
position_embeds = self.wpe(position_ids)
hidden_states = inputs_embeds + position_embeds
if token_type_ids is not None:
+ token_type_ids = token_type_ids.view(-1, input_shape[-1])
token_type_embeds = self.wte(token_type_ids)
hidden_states = hidden_states + token_type_embeds
hidden_states = self.drop(hidden_states)
-
output_shape = input_shape + (hidden_states.shape[-1],)
- presents = [] if use_cache else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
all_hidden_states = () if output_hidden_states else None
- for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
+ for i, block in enumerate(self.h):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
- if self.gradient_checkpointing and self.training:
- outputs = self._gradient_checkpointing_func(
- block.__call__,
- hidden_states,
- None,
- attention_mask,
- head_mask[i],
- encoder_hidden_states,
- encoder_attention_mask,
- use_cache,
- output_attentions,
- )
- else:
- outputs = block(
- hidden_states,
- layer_past=layer_past,
- attention_mask=attention_mask,
- head_mask=head_mask[i],
- encoder_hidden_states=encoder_hidden_states,
- encoder_attention_mask=encoder_attention_mask,
- use_cache=use_cache,
- output_attentions=output_attentions,
- )
+ outputs = block(
+ hidden_states,
+ past_key_values,
+ causal_mask,
+ head_mask[i],
+ encoder_hidden_states, # as a positional argument for gradient checkpointing
+ encoder_attention_mask=encoder_attention_mask,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ cache_position=cache_position,
+ **kwargs,
+ )
hidden_states = outputs[0]
- if use_cache:
- presents.append(outputs[1])
-
if output_attentions:
- all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
+ all_self_attentions = all_self_attentions + (outputs[1],)
if self.config.add_cross_attention:
- all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
+ all_cross_attentions = all_cross_attentions + (outputs[2],)
hidden_states = self.ln_f(hidden_states)
@@ -864,22 +577,24 @@ def construct(
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
- if not return_dict:
- return tuple(
- v
- for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
- if v is not None
- )
+ if return_legacy_cache:
+ past_key_values = past_key_values.to_legacy_cache()
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
- past_key_values=presents,
+ past_key_values=past_key_values,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
cross_attentions=all_cross_attentions,
)
+@auto_docstring(
+ custom_intro="""
+ The GPT_BIGCODE Model transformer with a language modeling head on top (linear layer with weights tied to the input
+ embeddings).
+ """
+)
class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
@@ -891,85 +606,11 @@ def __init__(self, config):
# Initialize weights and apply final processing
self.post_init()
- def get_output_embeddings(self):
- return self.lm_head
-
- def set_output_embeddings(self, new_embeddings):
- self.lm_head = new_embeddings
-
- def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
- # Overwritten -- `past_key_values` with uncommon shape
-
- token_type_ids = kwargs.get("token_type_ids", None)
- # Omit tokens covered by past_key_values
- if past_key_values:
- if self.config.multi_query:
- past_length = past_key_values[0].shape[1]
- else:
- past_length = past_key_values[0].shape[2]
-
- # Some generation methods already pass only the last input ID
- if input_ids.shape[1] > past_length:
- remove_prefix_length = past_length
- else:
- # Default to old behavior: keep only final ID
- remove_prefix_length = input_ids.shape[1] - 1
-
- input_ids = input_ids[:, remove_prefix_length:]
- if token_type_ids is not None:
- token_type_ids = token_type_ids[:, -input_ids.shape[1] :]
-
- attention_mask = kwargs.get("attention_mask", None)
- position_ids = kwargs.get("position_ids", None)
-
- if attention_mask is not None and position_ids is None:
- # create position_ids on the fly for batch generation
- position_ids = attention_mask.to(ms.int64).cumsum(-1) - 1
- position_ids.masked_fill_(attention_mask == 0, 1)
- if past_key_values:
- position_ids = position_ids[:, -input_ids.shape[1] :]
- else:
- position_ids = None
-
- # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
- if inputs_embeds is not None and past_key_values is None:
- model_inputs = {"inputs_embeds": inputs_embeds}
- else:
- model_inputs = {"input_ids": input_ids}
-
- model_inputs.update(
- {
- "past_key_values": past_key_values,
- "use_cache": kwargs.get("use_cache"),
- "position_ids": position_ids,
- "attention_mask": attention_mask,
- "token_type_ids": token_type_ids,
- }
- )
- return model_inputs
-
- def _get_initial_cache_position(self, input_ids, model_kwargs):
- """
- Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length.
- Since gpt bigcode is special, the method is overridden here, other models use it from `generation.utils.py`.
- """
- past_length = 0
- if "past_key_values" in model_kwargs:
- if self.config.multi_query:
- past_length = model_kwargs["past_key_values"][0].shape[1]
- else:
- past_length = model_kwargs["past_key_values"][0].shape[2]
- if "inputs_embeds" in model_kwargs:
- cur_len = model_kwargs["inputs_embeds"].shape[1]
- else:
- cur_len = input_ids.shape[-1]
- model_kwargs["cache_position"] = mint.arange(past_length, cur_len)
- return model_kwargs
-
+ @auto_docstring
def construct(
self,
input_ids: Optional[ms.Tensor] = None,
- past_key_values: Optional[Tuple[Tuple[ms.Tensor]]] = None,
+ past_key_values: Optional[tuple[tuple[ms.Tensor]]] = None,
attention_mask: Optional[ms.Tensor] = None,
token_type_ids: Optional[ms.Tensor] = None,
position_ids: Optional[ms.Tensor] = None,
@@ -982,9 +623,22 @@ def construct(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
+ cache_position: Optional[ms.Tensor] = None,
**kwargs,
- ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
+ ) -> Union[tuple, CausalLMOutputWithCrossAttentions]:
r"""
+ input_ids (`ms.Tensor` of shape `(batch_size, input_ids_length)`):
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
+ `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input
+ sequence tokens in the vocabulary.
+
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
+ `input_ids`.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
labels (`ms.Tensor` of shape `(batch_size, input_ids_length)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
@@ -1006,6 +660,7 @@ def construct(
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
+ cache_position=cache_position,
)
hidden_states = transformer_outputs[0]
@@ -1033,16 +688,21 @@ def construct(
cross_attentions=transformer_outputs.cross_attentions,
)
- @staticmethod
- def _reorder_cache(past_key_values: Tuple[Tuple[ms.Tensor]], beam_idx: ms.Tensor) -> Tuple[Tuple[ms.Tensor]]:
- """
- This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
- [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
- beam_idx at every generation step.
- """
- return tuple(layer_past.index_select(0) for layer_past in past_key_values)
+@auto_docstring(
+ custom_intro="""
+ The GPTBigCode Model transformer with a sequence classification head on top (linear layer).
+
+ [`GPTBigCodeForSequenceClassification`] uses the last token in order to do the classification, as other causal
+ models (e.g. GPT-1) do.
+ Since it does classification on the last token, it needs to know the position of the last token. If a
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
+ each row of the batch).
+ """
+)
class GPTBigCodeForSequenceClassification(GPTBigCodePreTrainedModel):
def __init__(self, config):
super().__init__(config)
@@ -1053,10 +713,11 @@ def __init__(self, config):
# Initialize weights and apply final processing
self.post_init()
+ @auto_docstring
def construct(
self,
input_ids: Optional[ms.Tensor] = None,
- past_key_values: Optional[Tuple[Tuple[ms.Tensor]]] = None,
+ past_key_values: Optional[tuple[tuple[ms.Tensor]]] = None,
attention_mask: Optional[ms.Tensor] = None,
token_type_ids: Optional[ms.Tensor] = None,
position_ids: Optional[ms.Tensor] = None,
@@ -1067,8 +728,21 @@ def construct(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
- ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
+ **kwargs,
+ ) -> Union[tuple, SequenceClassifierOutputWithPast]:
r"""
+ input_ids (`ms.Tensor` of shape `(batch_size, input_ids_length)`):
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
+ `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input
+ sequence tokens in the vocabulary.
+
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
+ `input_ids`.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
labels (`ms.Tensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
@@ -1088,6 +762,7 @@ def construct(
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
+ **kwargs,
)
hidden_states = transformer_outputs[0]
logits = self.score(hidden_states)
@@ -1103,12 +778,12 @@ def construct(
last_non_pad_token = -1
elif input_ids is not None:
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
- non_pad_mask = ms.Tensor((input_ids != self.config.pad_token_id), ms.int32)
+ non_pad_mask = (input_ids != self.config.pad_token_id).to(ms.int32)
token_indices = mint.arange(input_ids.shape[-1], dtype=ms.int32)
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
else:
last_non_pad_token = -1
- logger.warning(
+ logger.warning_once(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
@@ -1150,6 +825,7 @@ def construct(
)
+@auto_docstring
class GPTBigCodeForTokenClassification(GPTBigCodePreTrainedModel):
def __init__(self, config):
super().__init__(config)
@@ -1162,16 +838,17 @@ def __init__(self, config):
classifier_dropout = config.hidden_dropout
else:
classifier_dropout = 0.1
- self.dropout = mint.nn.Dropout(classifier_dropout)
+ self.dropout = nn.Dropout(classifier_dropout)
self.classifier = mint.nn.Linear(config.hidden_size, config.num_labels)
# Initialize weights and apply final processing
self.post_init()
+ @auto_docstring
def construct(
self,
input_ids: Optional[ms.Tensor] = None,
- past_key_values: Optional[Tuple[Tuple[ms.Tensor]]] = None,
+ past_key_values: Optional[tuple[tuple[ms.Tensor]]] = None,
attention_mask: Optional[ms.Tensor] = None,
token_type_ids: Optional[ms.Tensor] = None,
position_ids: Optional[ms.Tensor] = None,
@@ -1182,8 +859,20 @@ def construct(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
- ) -> Union[Tuple, TokenClassifierOutput]:
+ ) -> Union[tuple, TokenClassifierOutput]:
r"""
+ input_ids (`ms.Tensor` of shape `(batch_size, input_ids_length)`):
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
+ `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input
+ sequence tokens in the vocabulary.
+
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
+ `input_ids`.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
labels (`ms.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
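(Editorial aside.) The gpt_bigcode changes above collapse the per-backend attention subclasses into a single `GPTBigCodeAttention` that picks its kernel at call time via `eager_attention_forward` and `ALL_ATTENTION_FUNCTIONS`. For readers new to that pattern, here is a minimal, self-contained sketch of the dispatch idea; the function and registry names below are illustrative stand-ins, not the actual mindone implementations.

```python
# Minimal sketch of the run-time attention dispatch used above.
# `eager_attention` and `ATTENTION_FUNCTIONS` are illustrative stand-ins for
# mindone's `eager_attention_forward` and `ALL_ATTENTION_FUNCTIONS`.
from typing import Callable, Dict


def eager_attention(q, k, v, mask):
    # Stand-in for the default (eager) softmax attention path.
    return "eager-attention-output"


def flash_attention(q, k, v, mask):
    # Stand-in for a fused/flash attention kernel.
    return "flash-attention-output"


ATTENTION_FUNCTIONS: Dict[str, Callable] = {
    "flash_attention_2": flash_attention,
}


def run_attention(attn_implementation: str, q, k, v, mask):
    # Default to eager and only switch when a registered backend is requested,
    # mirroring the `if config._attn_implementation != "eager"` branch above.
    attention_interface: Callable = eager_attention
    if attn_implementation != "eager":
        attention_interface = ATTENTION_FUNCTIONS[attn_implementation]
    return attention_interface(q, k, v, mask)


print(run_attention("eager", None, None, None, None))              # eager-attention-output
print(run_attention("flash_attention_2", None, None, None, None))  # flash-attention-output
```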
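The `GPTBigCodeForSequenceClassification` docstring and the updated `non_pad_mask` logic above pick the rightmost non-padding token in each row for classification. Below is a tiny NumPy sketch of that index computation; the `pad_token_id` and inputs are made up for illustration and nothing beyond NumPy is assumed.

```python
# NumPy stand-in for the rightmost-non-padding-token selection; `pad_token_id`
# and the example inputs are made up for illustration.
import numpy as np

pad_token_id = 0
input_ids = np.array([
    [0, 0, 11, 12, 13],   # left-padded row -> last real token at index 4
    [21, 22, 23, 0, 0],   # right-padded row -> last real token at index 2
])

non_pad_mask = (input_ids != pad_token_id).astype(np.int32)
token_indices = np.arange(input_ids.shape[-1], dtype=np.int32)
# Multiplying zeroes out padded positions, so argmax returns the largest index
# that still holds a real (non-padding) token.
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
print(last_non_pad_token)  # [4 2]
```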
diff --git a/mindone/transformers/models/gpt_neo/modeling_gpt_neo.py b/mindone/transformers/models/gpt_neo/modeling_gpt_neo.py
index ee6d64e7ea..4934233412 100644
--- a/mindone/transformers/models/gpt_neo/modeling_gpt_neo.py
+++ b/mindone/transformers/models/gpt_neo/modeling_gpt_neo.py
@@ -36,6 +36,7 @@
from ...generation import GenerationMixin
from ...mindspore_adapter import dtype_to_min
from ...modeling_attn_mask_utils import AttentionMaskConverter
+from ...modeling_flash_attention_utils import is_flash_attn_available
from ...modeling_outputs import (
BaseModelOutputWithPast,
BaseModelOutputWithPastAndCrossAttentions,
@@ -45,9 +46,9 @@
SequenceClassifierOutputWithPast,
TokenClassifierOutput,
)
-from ...modeling_utils import PreTrainedModel, is_flash_attn_2_available
+from ...modeling_utils import PreTrainedModel
-if is_flash_attn_2_available():
+if is_flash_attn_available():
from ...integrations.flash_attention import flash_attention_forward
diff --git a/mindone/transformers/models/gpt_neox/modeling_gpt_neox.py b/mindone/transformers/models/gpt_neox/modeling_gpt_neox.py
index e4c97a9e11..cc47820235 100644
--- a/mindone/transformers/models/gpt_neox/modeling_gpt_neox.py
+++ b/mindone/transformers/models/gpt_neox/modeling_gpt_neox.py
@@ -7,7 +7,6 @@
from typing import Callable, Optional, Tuple, Union
from transformers.models.gpt_neox.configuration_gpt_neox import GPTNeoXConfig
-from transformers.utils import LossKwargs
import mindspore as ms
from mindspore import Parameter, mint, nn
@@ -27,7 +26,7 @@
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
-from ...utils import logging
+from ...utils import TransformersKwargs, logging
logger = logging.get_logger(__name__)
@@ -36,10 +35,6 @@ class HybridCache(object):
"""This class do nothing and will be never used in our implement."""
-class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs):
- ...
-
-
class GPTNeoXMLP(nn.Cell):
def __init__(self, config):
super().__init__()
@@ -578,7 +573,7 @@ def construct(
output_hidden_states: Optional[bool] = None,
cache_position: Optional[ms.Tensor] = None,
logits_to_keep: Union[int, ms.Tensor] = 0,
- **kwargs: Unpack[KwargsForCausalLM],
+ **kwargs: Unpack[TransformersKwargs],
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
labels (`ms.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
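(Editorial aside.) This file and several later ones replace the ad-hoc `KwargsForCausalLM` class with `Unpack[TransformersKwargs]`. The snippet below is a simplified sketch of what the typed-kwargs pattern provides, using a stand-in `TypedDict` rather than the real `TransformersKwargs` definition (assumed here to live in `mindone.transformers.utils`).

```python
# Simplified stand-in for the `Unpack[TransformersKwargs]` pattern; the real
# `TransformersKwargs` is assumed to live in `mindone.transformers.utils` and
# is richer than this illustrative TypedDict.
from typing import Optional, TypedDict

try:
    from typing import Unpack  # Python 3.11+
except ImportError:
    from typing_extensions import Unpack


class GenerationKwargs(TypedDict, total=False):
    output_attentions: Optional[bool]
    output_hidden_states: Optional[bool]


def construct(input_ids, **kwargs: Unpack[GenerationKwargs]):
    # Static type checkers can now validate the extra keyword arguments by
    # name instead of accepting a completely untyped **kwargs.
    return {"input_ids": input_ids, **kwargs}


print(construct([1, 2, 3], output_attentions=True))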
diff --git a/mindone/transformers/models/granite/modeling_granite.py b/mindone/transformers/models/granite/modeling_granite.py
index 584375a0ab..05c70238bc 100644
--- a/mindone/transformers/models/granite/modeling_granite.py
+++ b/mindone/transformers/models/granite/modeling_granite.py
@@ -26,7 +26,6 @@
from transformers.models.granite.configuration_granite import GraniteConfig
from transformers.utils import (
- LossKwargs,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
@@ -46,6 +45,7 @@
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
+from ...utils import TransformersKwargs
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "GraniteConfig"
@@ -749,10 +749,6 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
return causal_mask
-class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs):
- ...
-
-
class GraniteForCausalLM(GranitePreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
@@ -801,7 +797,7 @@ def construct(
return_dict: Optional[bool] = None,
cache_position: Optional[ms.Tensor] = None,
logits_to_keep: Union[int, ms.Tensor] = 0,
- **kwargs: Unpack[KwargsForCausalLM],
+ **kwargs: Unpack[TransformersKwargs],
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
labels (`ms.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
diff --git a/mindone/transformers/models/helium/modeling_helium.py b/mindone/transformers/models/helium/modeling_helium.py
index b5200fe302..338e57cc7b 100644
--- a/mindone/transformers/models/helium/modeling_helium.py
+++ b/mindone/transformers/models/helium/modeling_helium.py
@@ -3,7 +3,6 @@
from transformers import HeliumConfig
from transformers.utils import (
- LossKwargs,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
@@ -32,6 +31,7 @@
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
+from ...utils import TransformersKwargs
logger = logging.get_logger(__name__)
@@ -726,10 +726,6 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
return causal_mask
-class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs):
- ...
-
-
class HeliumForCausalLM(HeliumPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
@@ -779,7 +775,7 @@ def construct(
return_dict: Optional[bool] = None,
cache_position: Optional[ms.Tensor] = None,
logits_to_keep: Union[int, ms.Tensor] = 0,
- **kwargs: Unpack[KwargsForCausalLM],
+ **kwargs: Unpack[TransformersKwargs],
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
labels (`ms.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
diff --git a/mindone/transformers/models/idefics/processing_idefics.py b/mindone/transformers/models/idefics/processing_idefics.py
index ec9ecb3b59..c2bf332325 100644
--- a/mindone/transformers/models/idefics/processing_idefics.py
+++ b/mindone/transformers/models/idefics/processing_idefics.py
@@ -30,14 +30,7 @@
from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput
-from ...processing_utils import (
- ImagesKwargs,
- ProcessingKwargs,
- ProcessorMixin,
- TextKwargs,
- Unpack,
- _validate_images_text_input_order,
-)
+from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, TextKwargs, Unpack
IMAGE_TOKEN = "<image>"
@@ -297,8 +290,6 @@ def __call__(
"""
if images is None and text is None:
raise ValueError("You need to specify either `text` or `images` and `text`.")
- # check if images and text inputs are reversed for BC
- images, text = _validate_images_text_input_order(images, text)
if images is None:
# assuming the user wants to use the old behavior with prompts as the only argument
diff --git a/mindone/transformers/models/llama/modeling_llama.py b/mindone/transformers/models/llama/modeling_llama.py
index 57e085764c..1fcfd86455 100644
--- a/mindone/transformers/models/llama/modeling_llama.py
+++ b/mindone/transformers/models/llama/modeling_llama.py
@@ -24,7 +24,7 @@
import numpy as np
from transformers import LlamaConfig
-from transformers.utils import LossKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, logging
+from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging
import mindspore as ms
from mindspore import Parameter, Tensor, mint, nn, ops
@@ -41,6 +41,7 @@
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
+from ...utils import TransformersKwargs
logger = logging.get_logger(__name__)
@@ -387,6 +388,8 @@ class LlamaPreTrainedModel(PreTrainedModel):
_supports_quantized_cache = False
_supports_static_cache = False # StaticCache, not used
_supports_attention_backend = True
+ _supports_jit = True
+ _is_stateful = True
def _init_weights(self, cell):
std = self.config.initializer_range
@@ -756,10 +759,6 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
return causal_mask
-class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs):
- ...
-
-
class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
@@ -815,7 +814,7 @@ def construct(
return_dict: Optional[bool] = False,
cache_position: Optional[ms.Tensor] = None,
logits_to_keep: Union[int, ms.Tensor] = 0,
- **kwargs: Unpack[KwargsForCausalLM],
+ **kwargs: Unpack[TransformersKwargs],
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
diff --git a/mindone/transformers/models/llava/processing_llava.py b/mindone/transformers/models/llava/processing_llava.py
index a6b29f75ac..be0162c9e6 100644
--- a/mindone/transformers/models/llava/processing_llava.py
+++ b/mindone/transformers/models/llava/processing_llava.py
@@ -21,10 +21,11 @@
from typing import List, Optional, Union
+from transformers.tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
+
from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput
from ...processing_utils import ProcessorMixin
-from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
from ...utils import TensorType
diff --git a/mindone/transformers/models/llava_next/modeling_llava_next.py b/mindone/transformers/models/llava_next/modeling_llava_next.py
index 575474bd1b..7aa9cca16a 100644
--- a/mindone/transformers/models/llava_next/modeling_llava_next.py
+++ b/mindone/transformers/models/llava_next/modeling_llava_next.py
@@ -32,12 +32,15 @@
from mindone.models.utils import normal_, zeros_
from ...activations import ACT2FN
+from ...cache_utils import Cache
from ...generation import GenerationMixin
from ...image_processing_utils import select_best_resolution
-from ...modeling_outputs import ModelOutput
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
from ...modeling_utils import PreTrainedModel
+from ...processing_utils import Unpack
from ...utils import logging
-from ..auto import AutoModel, AutoModelForCausalLM
+from ..auto import AutoModel
logger = logging.get_logger(__name__)
@@ -151,6 +154,23 @@ def unpad_image(tensor, original_size):
return unpadded_tensor
+@dataclass
+class LlavaNextModelOutputWithPast(BaseModelOutputWithPast):
+ r"""
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(ms.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ image_hidden_states (`ms.Tensor`, *optional*):
+ A `ms.Tensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
+ image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
+ """
+
+ image_hidden_states: Optional[ms.Tensor] = None
+
+
@dataclass
class LlavaNextCausalLMOutputWithPast(ModelOutput):
"""
@@ -250,7 +270,9 @@ def _init_weights(self, module):
module.weight[module.padding_idx] = 0
-class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixin):
+class LlavaNextModel(LlavaNextPreTrainedModel):
+ _checkpoint_conversion_mapping = {"language_model.model": "language_model"}
+
def __init__(self, config: LlavaNextConfig):
super().__init__(config)
# TODO: remove the config fix once they are fixed.
@@ -266,7 +288,7 @@ def __init__(self, config: LlavaNextConfig):
# TODO: remove the config fix once they are fixed.
config.text_config._attn_implementation = config._attn_implementation
config.text_config.mindspore_dtype = getattr(config, "mindspore_dtype", None)
- self.language_model = AutoModelForCausalLM.from_config(config.text_config)
+ self.language_model = AutoModel.from_config(config.text_config)
if self.language_model._tied_weights_keys is not None:
self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]
@@ -437,61 +459,28 @@ def get_image_features(
def construct(
self,
- input_ids: Optional[ms.Tensor] = None,
- pixel_values: Optional[ms.Tensor] = None,
+ input_ids: ms.Tensor = None,
+ pixel_values: ms.Tensor = None,
image_sizes: Optional[ms.Tensor] = None,
attention_mask: Optional[ms.Tensor] = None,
position_ids: Optional[ms.Tensor] = None,
- past_key_values: Optional[List[ms.Tensor]] = None,
+ past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[ms.Tensor] = None,
- vision_feature_layer: Optional[Union[int, List[int]]] = None,
+ vision_feature_layer: Optional[Union[int, list[int]]] = None,
vision_feature_select_strategy: Optional[str] = None,
- labels: Optional[ms.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[ms.Tensor] = None,
- logits_to_keep: Union[int, ms.Tensor] = 0,
- **lm_kwargs,
- ) -> Union[Tuple, LlavaNextCausalLMOutputWithPast]:
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> Union[tuple, LlavaNextModelOutputWithPast]:
r"""
- labels (`ms.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
- config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
-
- logits_to_keep (`int` or `ms.Tensor`, *optional*):
- If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
- `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
- token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
- If a `ms.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
- This is useful when using packed tensor format (single dimension for batch and sequence length).
-
- Returns:
-
- Example:
-
- ```python
- >>> from PIL import Image
- >>> import requests
- >>> from mindone.transformers import AutoProcessor, LlavaNextForConditionalGeneration
-
- >>> model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
- >>> processor = AutoProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
-
- >>> prompt = "[INST] <image>\nWhat is shown in this image? [/INST]"
- >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
- >>> image = Image.open(requests.get(url, stream=True).raw)
-
- >>> inputs = processor(images=image, text=prompt, return_tensors="ms")
-
- >>> # Generate
- >>> generate_ids = model.generate(**inputs, max_new_tokens=30)
- >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
- "[INST] \nWhat is shown in this image? [/INST] The image appears to be a radar chart, which is a type of multi-dimensional plot (...)"
- ```"""
-
+ vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
+ The feature selection strategy used to select the vision feature from the vision backbone.
+ Can be one of `"default"` or `"full"`. If `"default"`, the CLS token is removed from the vision features.
+ If `"full"`, the full vision features are used.
+ """
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -557,31 +546,181 @@ def construct(
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
- logits_to_keep=logits_to_keep,
- **lm_kwargs,
+ **kwargs,
+ )
+
+ return LlavaNextModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=image_features if pixel_values is not None else None,
+ )
+
+
+class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixin):
+ _checkpoint_conversion_mapping = {
+ "^language_model.model": "model.language_model",
+ "^vision_tower": "model.vision_tower",
+ "^multi_modal_projector": "model.multi_modal_projector",
+ "^image_newline": "model.image_newline",
+ "^language_model.lm_head": "lm_head",
+ }
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config: LlavaNextConfig):
+ super().__init__(config)
+ self.model = LlavaNextModel(config)
+ self.lm_head = mint.nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.model.set_input_embeddings(value)
+
+ def get_output_embeddings(self) -> nn.Cell:
+ return self.lm_head
+
+ def set_decoder(self, decoder):
+ self.model.set_decoder(decoder)
+
+ def get_decoder(self):
+ return self.model.get_decoder()
+
+ def pack_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None):
+ return self.model.pack_image_features(
+ image_features=image_features,
+ image_sizes=image_sizes,
+ vision_feature_select_strategy=vision_feature_select_strategy,
+ image_newline=image_newline,
+ )
+
+ def get_image_features(
+ self,
+ pixel_values: ms.Tensor,
+ image_sizes: ms.Tensor,
+ vision_feature_layer: Optional[Union[int, list[int]]] = None,
+ vision_feature_select_strategy: Optional[str] = None,
+ ):
+ return self.model.get_image_features(
+ pixel_values=pixel_values,
+ image_sizes=image_sizes,
+ vision_feature_layer=vision_feature_layer,
+ vision_feature_select_strategy=vision_feature_select_strategy,
+ )
+
+ # Make modules available through the conditional class for BC
+ @property
+ def language_model(self):
+ return self.model.language_model
+
+ @property
+ def vision_tower(self):
+ return self.model.vision_tower
+
+ @property
+ def multi_modal_projector(self):
+ return self.model.multi_modal_projector
+
+ def construct(
+ self,
+ input_ids: Optional[ms.Tensor] = None,
+ pixel_values: Optional[ms.Tensor] = None,
+ image_sizes: Optional[ms.Tensor] = None,
+ attention_mask: Optional[ms.Tensor] = None,
+ position_ids: Optional[ms.Tensor] = None,
+ past_key_values: Optional[List[ms.Tensor]] = None,
+ inputs_embeds: Optional[ms.Tensor] = None,
+ vision_feature_layer: Optional[Union[int, List[int]]] = None,
+ vision_feature_select_strategy: Optional[str] = None,
+ labels: Optional[ms.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[ms.Tensor] = None,
+ logits_to_keep: Union[int, ms.Tensor] = 0,
+ **kwargs,
+ ) -> Union[Tuple, LlavaNextCausalLMOutputWithPast]:
+ r"""
+ labels (`ms.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ logits_to_keep (`int` or `ms.Tensor`, *optional*):
+ If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
+ `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+ token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+ If a `ms.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+ This is useful when using packed tensor format (single dimension for batch and sequence length).
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from mindone.transformers import AutoProcessor, LlavaNextForConditionalGeneration
+
+ >>> model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
+ >>> processor = AutoProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
+
+ >>> prompt = "[INST] <image>\nWhat is shown in this image? [/INST]"
+ >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> inputs = processor(images=image, text=prompt, return_tensors="ms")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(**inputs, max_new_tokens=30)
+ >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "[INST] \nWhat is shown in this image? [/INST] The image appears to be a radar chart, which is a type of multi-dimensional plot (...)"
+ ```"""
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ vision_feature_layer = (
+ vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
+ )
+ vision_feature_select_strategy = (
+ vision_feature_select_strategy
+ if vision_feature_select_strategy is not None
+ else self.config.vision_feature_select_strategy
+ )
+
+ outputs = self.model(
+ input_ids,
+ pixel_values=pixel_values,
+ image_sizes=image_sizes,
+ vision_feature_layer=vision_feature_layer,
+ vision_feature_select_strategy=vision_feature_select_strategy,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ cache_position=cache_position,
+ **kwargs,
)
- logits = outputs[0]
+ hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
- # Shift so that tokens < n predict n
- if attention_mask is not None:
- # we use the input attention mask to shift the logits and labels, because it is 2D.
- # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
- shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :]
- shift_logits = logits[..., :-1, :][shift_attention_mask != 0].contiguous()
- shift_labels = labels[..., 1:][shift_attention_mask != 0].contiguous()
- else:
- shift_logits = logits[..., :-1, :].contiguous()
- shift_labels = labels[..., 1:].contiguous()
- # Flatten the tokens
- loss_fct = mint.nn.CrossEntropyLoss()
- loss = loss_fct(shift_logits.view(-1, shift_logits.shape[-1]), shift_labels.view(-1))
-
- if not return_dict:
- output = (logits,) + outputs[1:]
- return (loss,) + output if loss is not None else output
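+ # Note: self.loss_function is assumed to handle the causal shift-by-one and to ignore -100 labels,
+ # matching the manual shift/CrossEntropyLoss logic it replaces above.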
+ loss = self.loss_function(
+ logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
+ )
return LlavaNextCausalLMOutputWithPast(
loss=loss,
@@ -589,7 +728,7 @@ def construct(
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
- image_hidden_states=image_features if pixel_values is not None else None,
+ image_hidden_states=outputs.image_hidden_states,
)
def prepare_inputs_for_generation(
diff --git a/mindone/transformers/models/llava_next/processing_llava_next.py b/mindone/transformers/models/llava_next/processing_llava_next.py
index de1fc14339..8d67f00419 100644
--- a/mindone/transformers/models/llava_next/processing_llava_next.py
+++ b/mindone/transformers/models/llava_next/processing_llava_next.py
@@ -22,14 +22,13 @@
from typing import List, Union
+import numpy as np
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
-import mindspore as ms
-
from ...feature_extraction_utils import BatchFeature
from ...image_processing_utils import select_best_resolution
from ...image_utils import ImageInput, get_image_size, to_numpy_array
-from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order
+from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack
from ...utils import logging
logger = logging.get_logger(__name__)
@@ -39,6 +38,7 @@ class LlavaNextProcessorKwargs(ProcessingKwargs, total=False):
_defaults = {
"text_kwargs": {
"padding": False,
+ "return_mm_token_type_ids": False,
},
"images_kwargs": {
"do_pad": True,
@@ -62,7 +62,7 @@ class LlavaNextProcessor(ProcessorMixin):
Patch size from the vision tower.
vision_feature_select_strategy (`str`, *optional*):
The feature selection strategy used to select the vision feature from the vision backbone.
- Shoudl be same as in model's config
+ Should be same as in model's config
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
in a chat into a tokenizable string.
image_token (`str`, *optional*, defaults to `"<image>"`):
@@ -73,13 +73,6 @@ class LlavaNextProcessor(ProcessorMixin):
"""
attributes = ["image_processor", "tokenizer"]
- valid_kwargs = [
- "chat_template",
- "patch_size",
- "vision_feature_select_strategy",
- "image_token",
- "num_additional_image_tokens",
- ]
image_processor_class = "AutoImageProcessor"
tokenizer_class = "AutoTokenizer"
@@ -98,6 +91,11 @@ def __init__(
self.num_additional_image_tokens = num_additional_image_tokens
self.vision_feature_select_strategy = vision_feature_select_strategy
self.image_token = tokenizer.image_token if hasattr(tokenizer, "image_token") else image_token
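+ # Resolve the numeric id of the image placeholder token, falling back to convert_tokens_to_ids
+ # when the tokenizer does not expose an image_token_id attribute.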
+ self.image_token_id = (
+ tokenizer.image_token_id
+ if getattr(tokenizer, "image_token_id", None)
+ else tokenizer.convert_tokens_to_ids(self.image_token)
+ )
super().__init__(image_processor, tokenizer, chat_template=chat_template)
def __call__(
@@ -112,7 +110,7 @@ def __call__(
Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode
the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to
- LlavaNextImageProcessor's [`~LlavaNextImageProcessor.__call__`] if `images` is not `None`. Please refer to the doctsring
+ LlavaNextImageProcessor's [`~LlavaNextImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
of the above two methods for more information.
Args:
@@ -135,8 +133,6 @@ def __call__(
"""
if images is None and text is None:
raise ValueError("You have to specify at least images or text.")
- # check if images and text inputs are reversed for BC
- images, text = _validate_images_text_input_order(images, text)
output_kwargs = self._merge_kwargs(
LlavaNextProcessorKwargs,
@@ -151,7 +147,7 @@ def __call__(
if isinstance(text, str):
text = [text]
elif not isinstance(text, list) and not isinstance(text[0], str):
- raise ValueError("Invalid input text. Please provide a string, or a list of strings")
+ raise TypeError("Invalid input text. Please provide a string, or a list of strings")
prompt_strings = text
if image_inputs:
@@ -172,12 +168,18 @@ def __call__(
prompt_strings.append(sample)
prompt_strings = [sample.replace("<placeholder>", self.image_token) for sample in prompt_strings]
- output_kwargs["text_kwargs"].pop("return_tensors", None)
- text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"], return_tensors="np")
- for k, v in text_inputs.items():
- text_inputs[k] = ms.tensor(v)
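+ # Tensor conversion is now deferred to BatchFeature via tensor_type instead of eagerly wrapping numpy arrays in ms.Tensor.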
+ return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
+ return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", None)
+ text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])
+ self._check_special_mm_tokens(prompt_strings, text_inputs, modalities=["image"])
- return BatchFeature(data={**text_inputs, **image_inputs})
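+ # When requested, emit a 0/1 mask marking which positions in input_ids are image placeholder tokens.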
+ if return_mm_token_type_ids:
+ array_ids = np.array(text_inputs["input_ids"])
+ mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
+ mm_token_type_ids[array_ids == self.image_token_id] = 1
+ text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
+
+ return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)
def _get_number_of_features(self, orig_height: int, orig_width: int, height: int, width: int) -> int:
image_grid_pinpoints = self.image_processor.image_grid_pinpoints
@@ -221,6 +223,48 @@ def _get_unpadded_features(self, height, width, patches_height, patches_width, s
newline_features = current_height
return (unpadded_features, newline_features)
+ def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
+ """
+ Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
+ Args:
+ image_sizes (list[list[int]], *optional*):
+ The input sizes formatted as (height, width) for each image.
+ video_sizes (list[list[int]], *optional*):
+ The input sizes formatted as (num_frames, height, width) for each video.
+ audio_lengths (list[int], *optional*):
+ The input length of each audio.
+ Returns:
+ dict[str, list[int]]: A dictionary mapping each modality ("image", "video", "audio")
+ to a list containing the number of placeholder tokens required. If the model doesn't accept
+ a certain modality or no input sizes are provided, the dict value is set to an empty list.
+ """
+ vision_data = {}
+ if image_sizes is not None:
+ images_kwargs = LlavaNextProcessorKwargs._defaults.get("images_kwargs", {})
+ images_kwargs.update(kwargs)
+
+ size = images_kwargs.get("size", None) or self.image_processor.size
+ size = (
+ (size["shortest_edge"], size["shortest_edge"])
+ if "shortest_edge" in size
+ else (min(size["height"], size["width"]), min(size["height"], size["width"]))
+ )
+ processed_height, processed_width = size
+
+ batch_num_image_tokens = []
+ num_image_patches = [1] * len(image_sizes) # llava-next doesn't batch pixels as Idefics does, thus 1 patch per image
+ for image_size in image_sizes:
+ orig_height, orig_width = image_size
+ num_image_tokens = self._get_number_of_features(
+ orig_height, orig_width, processed_height, processed_width
+ )
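+ # The "default" strategy drops the CLS token from the vision features, so one fewer placeholder is needed.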
+ if self.vision_feature_select_strategy == "default":
+ num_image_tokens -= 1
+ batch_num_image_tokens.append(num_image_tokens)
+ vision_data.update({"num_image_tokens": batch_num_image_tokens, "num_image_patches": num_image_patches})
+
+ return MultiModalData(**vision_data)
+
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
def batch_decode(self, *args, **kwargs):
"""
diff --git a/mindone/transformers/models/llava_next_video/image_processing_llava_next_video.py b/mindone/transformers/models/llava_next_video/image_processing_llava_next_video.py
index 3981c8944f..91b7af98e3 100644
--- a/mindone/transformers/models/llava_next_video/image_processing_llava_next_video.py
+++ b/mindone/transformers/models/llava_next_video/image_processing_llava_next_video.py
@@ -30,15 +30,14 @@
ChannelDimension,
ImageInput,
PILImageResampling,
- VideoInput,
infer_channel_dimension_format,
is_scaled_image,
- make_batched_videos,
make_list_of_images,
to_numpy_array,
validate_preprocess_arguments,
)
from ...utils import TensorType, logging
+from ...video_utils import VideoInput, make_batched_videos
logger = logging.get_logger(__name__)
diff --git a/mindone/transformers/models/llava_next_video/modeling_llava_next_video.py b/mindone/transformers/models/llava_next_video/modeling_llava_next_video.py
index 4a22d72532..211b419ba8 100644
--- a/mindone/transformers/models/llava_next_video/modeling_llava_next_video.py
+++ b/mindone/transformers/models/llava_next_video/modeling_llava_next_video.py
@@ -31,18 +31,43 @@
from mindone.models.utils import normal_, zeros_
from ...activations import ACT2FN
+from ...cache_utils import Cache
from ...generation import GenerationMixin
from ...image_processing_utils import select_best_resolution
-from ...modeling_outputs import ModelOutput
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
from ...modeling_utils import PreTrainedModel
+from ...processing_utils import Unpack
from ...utils import logging
-from ..auto import AutoModel, AutoModelForCausalLM
+from ..auto import AutoModel
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "LlavaNextVideoConfig"
+@dataclass
+class LlavaNextVideoModelOutputWithPast(BaseModelOutputWithPast):
+ r"""
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(ms.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ image_hidden_states (`ms.Tensor`, *optional*):
+ A `ms.Tensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
+ image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
+ video_hidden_states (`ms.Tensor`, *optional*):
+ A `ms.Tensor` of size `(batch_size * num_frames, num_videos, sequence_length, hidden_size)`.
+ video_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
+ """
+
+ image_hidden_states: Optional[ms.Tensor] = None
+
+ video_hidden_states: Optional[ms.Tensor] = None
+
+
@dataclass
class LlavaNextVideoCausalLMOutputWithPast(ModelOutput):
"""
@@ -286,7 +311,9 @@ def unpad_image(tensor, original_size):
return unpadded_tensor
-class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, GenerationMixin):
+class LlavaNextVideoModel(LlavaNextVideoPreTrainedModel):
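+ # Remaps parameter names from checkpoints saved with the previous module layout, presumably applied at load time.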
+ _checkpoint_conversion_mapping = {"language_model.model": "language_model"}
+
def __init__(
self,
config: LlavaNextVideoConfig,
@@ -305,7 +332,7 @@ def __init__(
# TODO: remove the config fix once they are fixed.
config.text_config._attn_implementation = config._attn_implementation
config.text_config.mindspore_dtype = getattr(config, "mindspore_dtype", None)
- self.language_model = AutoModelForCausalLM.from_config(config.text_config)
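+ # The decoder is now the headless AutoModel; the lm_head lives in LlavaNextVideoForConditionalGeneration below.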
+ self.language_model = AutoModel.from_config(config.text_config)
if self.language_model._tied_weights_keys is not None:
self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]
@@ -468,6 +495,245 @@ def get_image_features(
image_features = mint.split(image_features, image_num_patches, dim=0)
return image_features
+ def construct(
+ self,
+ input_ids: ms.Tensor = None,
+ pixel_values: ms.Tensor = None,
+ pixel_values_videos: ms.Tensor = None,
+ image_sizes: Optional[ms.Tensor] = None,
+ attention_mask: Optional[ms.Tensor] = None,
+ position_ids: Optional[ms.Tensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[ms.Tensor] = None,
+ vision_feature_layer: Optional[Union[int, list[int]]] = None,
+ vision_feature_select_strategy: Optional[str] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[ms.Tensor] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> Union[tuple, LlavaNextVideoModelOutputWithPast]:
+ r"""
+ vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
+ The feature selection strategy used to select the vision feature from the vision backbone.
+ Can be one of `"default"` or `"full"`. If `"default"`, the CLS token is removed from the vision features.
+ If `"full"`, the full vision features are used.
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ self.vision_feature_layer = (
+ vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
+ )
+ self.vision_feature_select_strategy = (
+ vision_feature_select_strategy
+ if vision_feature_select_strategy is not None
+ else self.config.vision_feature_select_strategy
+ )
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if (pixel_values is not None or pixel_values_videos is not None) and inputs_embeds is not None:
+ raise ValueError(
+ "You cannot specify both `pixel_values`/`pixel_values_videos` and `inputs_embeds` at the same time, "
+ "and must specify either one"
+ )
+
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+
+ if pixel_values is not None and pixel_values.shape[0] > 0:
+ image_features = self.get_image_features(
+ pixel_values,
+ image_sizes,
+ vision_feature_layer=self.vision_feature_layer,
+ vision_feature_select_strategy=self.vision_feature_select_strategy,
+ )
+ image_features, feature_lens = self.pack_image_features(
+ image_features,
+ image_sizes,
+ self.vision_feature_select_strategy,
+ image_newline=self.image_newline,
+ )
+
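+ # Scatter the projected image features into the text embeddings at the positions of image placeholder tokens.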
+ special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
+ special_image_mask = special_image_mask.expand_as(inputs_embeds)
+ if inputs_embeds[special_image_mask].numel() != image_features.numel():
+ n_image_tokens = (input_ids == self.config.image_token_index).sum()
+ n_image_features = image_features.shape[0]
+ raise ValueError(
+ f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+ )
+ image_features = image_features.to(inputs_embeds.dtype)
+ # TODO: remove cast
+ inputs_embeds = (
+ inputs_embeds.float().masked_scatter(special_image_mask, image_features.float()).to(inputs_embeds.dtype)
+ )
+
+ if pixel_values_videos is not None and pixel_values_videos.shape[0] > 0:
+ video_features = self.get_video_features(
+ pixel_values_videos,
+ vision_feature_layer=self.vision_feature_layer,
+ vision_feature_select_strategy=self.vision_feature_select_strategy,
+ )
+ video_features = [feature.flatten(0, 1) for feature in video_features]
+ video_feature_lens = [feature.shape[0] for feature in video_features]
+ video_features = mint.cat(video_features, dim=0)
+ video_feature_lens = ms.tensor(video_feature_lens, dtype=ms.int64)
+
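+ # Same placeholder-scatter as for images, applied to the positions of video placeholder tokens.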
+ special_image_mask = (input_ids == self.config.video_token_index).unsqueeze(-1)
+ special_image_mask = special_image_mask.expand_as(inputs_embeds)
+ if inputs_embeds[special_image_mask].numel() != video_features.numel():
+ n_video_tokens = (input_ids == self.config.video_token_index).sum().item()
+ n_video_features = video_features.shape[0]
+ raise ValueError(
+ f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
+ )
+ video_features = video_features.to(inputs_embeds.dtype)
+ # TODO: remove cast
+ inputs_embeds = (
+ inputs_embeds.float().masked_scatter(special_image_mask, video_features.float()).to(inputs_embeds.dtype)
+ )
+
+ outputs = self.language_model(
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ return LlavaNextVideoModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=image_features if pixel_values is not None else None,
+ video_hidden_states=video_features if pixel_values_videos is not None else None,
+ )
+
+ def get_video_features(
+ self,
+ pixel_values: ms.Tensor,
+ vision_feature_layer: Union[int, List[int]],
+ vision_feature_select_strategy: str,
+ ):
+ """
+ Obtains video last hidden states from the vision tower and applies the multimodal projection.
+
+ Args:
+ pixel_values (`ms.Tensor` of shape `(batch_size, num_frames, channels, height, width)`):
+ The tensors corresponding to the input video.
+ vision_feature_layer (`Union[int, List[int]]`):
+ The index of the layer to select the vision feature. If multiple indices are provided,
+ the vision feature of the corresponding indices will be concatenated to form the
+ vision features.
+ vision_feature_select_strategy (`str`):
+ The feature selection strategy used to select the vision feature from the vision backbone.
+ Can be one of `"default"` or `"full"`
+ Returns:
+ video_features (List[`ms.Tensor`]): List of video feature tensors, each containing the visual features of all patches
+ and of shape `(num_videos, video_length, embed_dim)`.
+ """
+ batch_size, frames, channels, height, width = pixel_values.shape
+ pixel_values = pixel_values.reshape(batch_size * frames, channels, height, width)
+ video_features = self.vision_tower(pixel_values, output_hidden_states=True, return_dict=True)
+
+ # If we have one vision feature layer, return the corresponding hidden states,
+ # otherwise, select the hidden states of each feature layer and concatenate them
+ if isinstance(vision_feature_layer, int):
+ selected_video_features = video_features.hidden_states[vision_feature_layer]
+ else:
+ hs_pool = [video_features.hidden_states[layer_idx] for layer_idx in vision_feature_layer]
+ selected_video_features = mint.cat(hs_pool, dim=-1)
+
+ if vision_feature_select_strategy == "default":
+ selected_video_features = selected_video_features[:, 1:]
+ elif vision_feature_select_strategy == "full":
+ selected_video_features = selected_video_features
+
+ # Same as image features except that video has pooling layer
+ video_features = self.vision_resampler(selected_video_features)
+ video_features = self.multi_modal_projector(video_features)
+ video_features = mint.split(video_features, frames, dim=0)
+ return video_features
+
+
+class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, GenerationMixin):
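+ # Regex-keyed remapping of older flat checkpoints onto the new nested layout: vision/text submodules move under
+ # `model.`, while the old language_model.lm_head becomes the top-level lm_head.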
+ _checkpoint_conversion_mapping = {
+ "^language_model.model": "model.language_model",
+ "^vision_tower": "model.vision_tower",
+ "^multi_modal_projector": "model.multi_modal_projector",
+ "^image_newline": "model.image_newline",
+ "^language_model.lm_head": "lm_head",
+ }
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config: LlavaNextVideoConfig):
+ super().__init__(config)
+ self.model = LlavaNextVideoModel(config)
+ self.lm_head = mint.nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.model.set_input_embeddings(value)
+
+ def get_output_embeddings(self) -> nn.Cell:
+ return self.lm_head
+
+ def set_decoder(self, decoder):
+ self.model.set_decoder(decoder)
+
+ def get_decoder(self):
+ return self.model.get_decoder()
+
+ def pack_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None):
+ return self.model.pack_image_features(
+ image_features=image_features,
+ image_sizes=image_sizes,
+ vision_feature_select_strategy=vision_feature_select_strategy,
+ image_newline=image_newline,
+ )
+
+ def get_image_features(
+ self,
+ pixel_values: ms.Tensor,
+ image_sizes: ms.Tensor,
+ vision_feature_layer: Optional[Union[int, list[int]]] = None,
+ vision_feature_select_strategy: Optional[str] = None,
+ ):
+ return self.model.get_image_features(
+ pixel_values=pixel_values,
+ image_sizes=image_sizes,
+ vision_feature_layer=vision_feature_layer,
+ vision_feature_select_strategy=vision_feature_select_strategy,
+ )
+
+ # Make modules available through the conditional class for BC
+ @property
+ def language_model(self):
+ return self.model.language_model
+
+ @property
+ def vision_tower(self):
+ return self.model.vision_tower
+
+ @property
+ def multi_modal_projector(self):
+ return self.model.multi_modal_projector
+
def construct(
self,
input_ids: Optional[ms.Tensor] = None,
@@ -487,7 +753,7 @@ def construct(
return_dict: Optional[bool] = None,
cache_position: Optional[ms.Tensor] = None,
logits_to_keep: Union[int, ms.Tensor] = 0,
- **lm_kwargs,
+ **kwargs,
) -> Union[Tuple, LlavaNextVideoCausalLMOutputWithPast]:
r"""
pixel_values_videos (`mindspore.Tensor` of shape `(batch_size, num_frames, num_channels, image_size, image_size)):
@@ -569,115 +835,44 @@ def construct(
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- self.vision_feature_layer = (
+ vision_feature_layer = (
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
)
- self.vision_feature_select_strategy = (
+ vision_feature_select_strategy = (
vision_feature_select_strategy
if vision_feature_select_strategy is not None
else self.config.vision_feature_select_strategy
)
- if (input_ids is None) ^ (inputs_embeds is not None):
- raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
-
- if (pixel_values is not None or pixel_values_videos is not None) and inputs_embeds is not None:
- raise ValueError(
- "You cannot specify both `pixel_values`/`pixel_values_videos` and `inputs_embeds` at the same time, "
- "and must specify either one"
- )
-
- if inputs_embeds is None:
- inputs_embeds = self.get_input_embeddings()(input_ids)
-
- if pixel_values is not None and pixel_values.shape[0] > 0:
- image_features = self.get_image_features(
- pixel_values,
- image_sizes,
- vision_feature_layer=self.vision_feature_layer,
- vision_feature_select_strategy=self.vision_feature_select_strategy,
- )
- image_features, feature_lens = self.pack_image_features(
- image_features,
- image_sizes,
- self.vision_feature_select_strategy,
- image_newline=self.image_newline,
- )
-
- special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
- special_image_mask = special_image_mask.expand_as(inputs_embeds)
- if inputs_embeds[special_image_mask].numel() != image_features.numel():
- n_image_tokens = (input_ids == self.config.image_token_index).sum()
- n_image_features = image_features.shape[0]
- raise ValueError(
- f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
- )
- image_features = image_features.to(inputs_embeds.dtype)
- # TODO: remove cast
- inputs_embeds = (
- inputs_embeds.float().masked_scatter(special_image_mask, image_features.float()).to(inputs_embeds.dtype)
- )
-
- if pixel_values_videos is not None and pixel_values_videos.shape[0] > 0:
- video_features = self.get_video_features(
- pixel_values_videos,
- vision_feature_layer=self.vision_feature_layer,
- vision_feature_select_strategy=self.vision_feature_select_strategy,
- )
- video_features = [feature.flatten(0, 1) for feature in video_features]
- video_feature_lens = [feature.shape[0] for feature in video_features]
- video_features = mint.cat(video_features, dim=0)
- video_feature_lens = ms.tensor(video_feature_lens, dtype=ms.int64)
-
- special_image_mask = (input_ids == self.config.video_token_index).unsqueeze(-1)
- special_image_mask = special_image_mask.expand_as(inputs_embeds)
- if inputs_embeds[special_image_mask].numel() != video_features.numel():
- n_video_tokens = (input_ids == self.config.video_token_index).sum().item()
- n_video_features = video_features.shape[0]
- raise ValueError(
- f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
- )
- video_features = video_features.to(inputs_embeds.dtype)
- # TODO: remove cast
- inputs_embeds = (
- inputs_embeds.float().masked_scatter(special_image_mask, video_features.float()).to(inputs_embeds.dtype)
- )
-
- outputs = self.language_model(
+ outputs = self.model(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
+ pixel_values_videos=pixel_values_videos,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
+ vision_feature_layer=vision_feature_layer,
+ vision_feature_select_strategy=vision_feature_select_strategy,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
- return_dict=return_dict,
+ return_dict=True,
cache_position=cache_position,
- logits_to_keep=logits_to_keep,
- **lm_kwargs,
+ image_sizes=image_sizes,
+ **kwargs,
)
- logits = outputs[0]
+ hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
- # Shift so that tokens < n predict n
- if attention_mask is not None:
- # we use the input attention mask to shift the logits and labels, because it is 2D.
- # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
- shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :]
- shift_logits = logits[..., :-1, :][shift_attention_mask != 0].contiguous()
- shift_labels = labels[..., 1:][shift_attention_mask != 0].contiguous()
- else:
- shift_logits = logits[..., :-1, :].contiguous()
- shift_labels = labels[..., 1:].contiguous()
- # Flatten the tokens
- loss_fct = mint.nn.CrossEntropyLoss()
- loss = loss_fct(shift_logits.view(-1, shift_logits.shape[-1]), shift_labels.view(-1))
-
- if not return_dict:
- output = (logits,) + outputs[1:]
- return (loss,) + output if loss is not None else output
+ loss = self.loss_function(
+ logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
+ )
return LlavaNextVideoCausalLMOutputWithPast(
loss=loss,
@@ -685,8 +880,8 @@ def construct(
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
- image_hidden_states=image_features if pixel_values is not None else None,
- video_hidden_states=video_features if pixel_values_videos is not None else None,
+ image_hidden_states=outputs.image_hidden_states,
+ video_hidden_states=outputs.video_hidden_states,
)
def prepare_inputs_for_generation(
@@ -726,51 +921,5 @@ def prepare_inputs_for_generation(
return model_inputs
- def get_video_features(
- self,
- pixel_values: ms.Tensor,
- vision_feature_layer: Union[int, List[int]],
- vision_feature_select_strategy: str,
- ):
- """
- Obtains video last hidden states from the vision tower and apply multimodal projection.
-
- Args:
- pixel_values (`ms.Tensor]` of shape `(batch_size, num_frames, channels, height, width)`)
- The tensors corresponding to the input video.
- vision_feature_layer (`Union[int, List[int]]`):
- The index of the layer to select the vision feature. If multiple indices are provided,
- the vision feature of the corresponding indices will be concatenated to form the
- vision features.
- vision_feature_select_strategy (`str`):
- The feature selection strategy used to select the vision feature from the vision backbone.
- Can be one of `"default"` or `"full"`
- Returns:
- video_features (List[`ms.Tensor`]): List of video feature tensor, each contains all the visual feature of all patches
- and are of shape `(num_videos, video_length, embed_dim)`).
- """
- batch_size, frames, channels, height, width = pixel_values.shape
- pixel_values = pixel_values.reshape(batch_size * frames, channels, height, width)
- video_features = self.vision_tower(pixel_values, output_hidden_states=True, return_dict=True)
-
- # If we have one vision feature layer, return the corresponding hidden states,
- # otherwise, select the hidden states of each feature layer and concatenate them
- if isinstance(vision_feature_layer, int):
- selected_video_features = video_features.hidden_states[vision_feature_layer]
- else:
- hs_pool = [video_features.hidden_states[layer_idx] for layer_idx in vision_feature_layer]
- selected_video_features = mint.cat(hs_pool, dim=-1)
-
- if vision_feature_select_strategy == "default":
- selected_video_features = selected_video_features[:, 1:]
- elif vision_feature_select_strategy == "full":
- selected_video_features = selected_video_features
-
- # Same as image features except that video has pooling layer
- video_features = self.vision_resampler(selected_video_features)
- video_features = self.multi_modal_projector(video_features)
- video_features = mint.split(video_features, frames, dim=0)
- return video_features
-
__all__ = ["LlavaNextVideoForConditionalGeneration", "LlavaNextVideoPreTrainedModel"]
diff --git a/mindone/transformers/models/llava_next_video/processing_llava_next_video.py b/mindone/transformers/models/llava_next_video/processing_llava_next_video.py
index 417a1bcded..d4fb66609f 100644
--- a/mindone/transformers/models/llava_next_video/processing_llava_next_video.py
+++ b/mindone/transformers/models/llava_next_video/processing_llava_next_video.py
@@ -25,13 +25,12 @@
import numpy as np
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
-import mindspore as ms
-
from ...feature_extraction_utils import BatchFeature
from ...image_processing_utils import select_best_resolution
-from ...image_utils import ImageInput, VideoInput, get_image_size, to_numpy_array
-from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order
+from ...image_utils import ImageInput, get_image_size, to_numpy_array
+from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
from ...utils import logging
+from ...video_utils import VideoInput
logger = logging.get_logger(__name__)
@@ -57,7 +56,7 @@ class LlavaNextVideoProcessor(ProcessorMixin):
[`LlamaTokenizerFast`]. See the [`~LlavaNextVideoProcessor.__call__`] and [`~LlavaNextVideoProcessor.decode`] for more information.
Args:
- video_processor ([`LlavaNextVideoImageProcessor`], *optional*):
+ video_processor ([`LlavaNextVideoVideoProcessor`], *optional*):
The video processor is a required input.
image_processor ([`LlavaNextImageProcessor`], *optional*):
The image processor is a required input.
@@ -69,7 +68,7 @@ class LlavaNextVideoProcessor(ProcessorMixin):
Patch size from the vision tower.
vision_feature_select_strategy (`str`, *optional*):
The feature selection strategy used to select the vision feature from the vision backbone.
- Shoudl be same as in model's config
+ Should be same as in model's config
video_token (`str`, *optional*, defaults to `"