8 changes: 6 additions & 2 deletions mindnlp/core/ops/reduction.py
@@ -138,9 +138,13 @@ def unique(input, sorted=True, return_inverse=False, return_counts=False, dim=None):
     if USE_PYBOOST:
         return mindspore.mint.unique(input, sorted, return_inverse, return_counts, dim)
     out, inverse = ops.unique(input)
+    outs = (out,)
     if return_inverse:
-        return out, inverse
-    return out
+        outs += (inverse,)
+    if return_counts:
+        counts = (out == input).sum(0, keepdims=True)
+        outs += (counts,)
+    return outs if len(outs) > 1 else outs[0]

 # unique_consecutive
 def unique_consecutive(input, return_inverse=False, return_counts=False, dim=None):
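For reference, a minimal pure-Python sketch of the return convention the reworked `unique` adopts: extra outputs are appended to a tuple only when requested, and a bare value is returned when nothing extra was asked for. The `unique_like` helper below is hypothetical and only illustrates the pattern; plain lists stand in for mindspore tensors.

```python
# Hypothetical illustration of the tuple-assembly pattern used by the new unique().
def unique_like(values, return_inverse=False, return_counts=False):
    ordered = sorted(set(values))
    outs = (ordered,)
    if return_inverse:
        outs += ([ordered.index(v) for v in values],)   # index of each element in the unique list
    if return_counts:
        outs += ([values.count(v) for v in ordered],)   # occurrences of each unique value
    return outs if len(outs) > 1 else outs[0]

print(unique_like([1, 3, 2, 3, 1]))                       # [1, 2, 3]
print(unique_like([1, 3, 2, 3, 1], return_inverse=True))  # ([1, 2, 3], [0, 2, 1, 2, 0])
```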
8 changes: 4 additions & 4 deletions mindnlp/core/optim/sgd.py
@@ -58,10 +58,10 @@ def step(self, grads):
         loss = None
         start = 0
         for group in self.param_groups:
-            weight_decay = group['weight_decay']
+            weight_decay = float(group['weight_decay'])
             momentum = Tensor(group['momentum'], mindspore.float32)
             lr = Tensor(group['lr'], mindspore.float32)
-            dampening = group['dampening']
+            dampening = float(group['dampening'])
             nesterov = group['nesterov']
             maximize=group["maximize"]

@@ -85,8 +85,8 @@ def step(self, grads):
                 # d_p = buf
                 # new_p = p.add(d_p, alpha=-group['lr'])
                 # assign(p, new_p)
-                stat = ops.ones_like(p)
-                accum = ops.zeros_like(p)
+                stat = Tensor(ops.ones_like(p))
+                accum = Tensor(ops.zeros_like(p))
                 ops.optim.raw_sgd(p, d_p, lr, dampening, weight_decay, nesterov, accum, momentum, stat)

         return loss
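A short, hedged sketch of the coercions this change applies before the fused kernel call, assuming a plain mindspore install; the `group` dict below is a stand-in for one entry of `self.param_groups`, and the `raw_sgd` call itself is omitted.

```python
import mindspore
from mindspore import Tensor, ops

# Stand-in for one entry of self.param_groups.
group = {"lr": 0.1, "momentum": 0.9, "dampening": 0, "weight_decay": 1e-4}

weight_decay = float(group["weight_decay"])   # plain Python float, as in the diff
dampening = float(group["dampening"])         # guards against an int slipping through
momentum = Tensor(group["momentum"], mindspore.float32)
lr = Tensor(group["lr"], mindspore.float32)

p = Tensor([[0.5, -0.5]], mindspore.float32)  # a toy parameter
stat = Tensor(ops.ones_like(p))               # wrapped so the kernel receives real Tensors
accum = Tensor(ops.zeros_like(p))

print(type(weight_decay), dampening, lr, stat.shape)
```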
@@ -969,8 +969,8 @@ def forward(
         if labels is not None:
             # retrieve loss input_lengths from attention_mask
             labels = labels.astype(mindspore.int32)
-            if labels.max() >= self.config.vocab_size:
-                raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
+            # if labels.max() >= self.config.vocab_size:
+            #     raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
             attention_mask = (
                 attention_mask if attention_mask is not None else ops.ones_like(input_values, dtype=mindspore.int64)
             )
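Since the in-model range check is commented out above, callers may want to validate CTC labels themselves before the forward pass. A hedged, self-contained sketch (NumPy only; `vocab_size` and the -100 ignore index are illustrative assumptions):

```python
import numpy as np

vocab_size = 32                                   # stand-in for model.config.vocab_size
labels = np.array([[5, 7, 31, -100],
                   [2, 2, 9, -100]])              # -100 marks padded label positions

valid = labels[labels >= 0]                       # ignore padding before checking the range
if valid.size and valid.max() >= vocab_size:
    raise ValueError(f"Label values must be < vocab_size ({vocab_size}), got {valid.max()}")
```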
151 changes: 71 additions & 80 deletions mindnlp/transformers/models/dbrx/modeling_dbrx.py
@@ -21,7 +21,6 @@
 from mindnlp.core.nn import functional as F
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache, StaticCache
-from ...modeling_attn_mask_utils import AttentionMaskConverter
 from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
 from ...modeling_utils import PreTrainedModel
 from ....utils import (
@@ -35,6 +34,56 @@

 _CONFIG_FOR_DOC = "DbrxConfig"

+# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
+def _prepare_4d_causal_attention_mask_with_cache_position(
+    attention_mask: mindspore.Tensor,
+    sequence_length: int,
+    target_length: int,
+    dtype: mindspore.dtype,
+    min_dtype: float,
+    cache_position: mindspore.Tensor,
+    batch_size: int,
+):
+    """
+    Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+    `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
+
+    Args:
+        attention_mask (`mindspore.Tensor`):
+            A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
+        sequence_length (`int`):
+            The sequence length being processed.
+        target_length (`int`):
+            The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
+        dtype (`mindspore.dtype`):
+            The dtype to use for the 4D attention mask.
+        min_dtype (`float`):
+            The minimum value representable with the dtype `dtype`.
+        cache_position (`mindspore.Tensor`):
+            Indices depicting the position of the input sequence tokens in the sequence.
+        batch_size (`mindspore.Tensor`):
+            Batch size.
+    """
+    if attention_mask is not None and attention_mask.ndim == 4:
+        # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+        causal_mask = attention_mask
+    else:
+        causal_mask = ops.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype)
+        if sequence_length != 1:
+            causal_mask = ops.triu(causal_mask, diagonal=1)
+        causal_mask *= ops.arange(target_length) > cache_position.reshape(-1, 1)
+        causal_mask = causal_mask[None, None, :, :].broadcast_to((batch_size, 1, -1, -1))
+        if attention_mask is not None:
+            causal_mask = causal_mask.copy()  # copy to contiguous memory for in-place edit
+            mask_length = attention_mask.shape[-1]
+            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+            padding_mask = padding_mask == 0
+            causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                padding_mask, min_dtype
+            )
+
+    return causal_mask
+
+
 # Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding with Gemma->Dbrx
 class DbrxRotaryEmbedding(nn.Module):
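A hedged call sketch for the helper added above, assuming it can be imported from `mindnlp.transformers.models.dbrx.modeling_dbrx` and that mindspore is installed; shapes and dtypes are chosen only for illustration.

```python
import numpy as np
import mindspore
from mindspore import Tensor
from mindnlp.transformers.models.dbrx.modeling_dbrx import (
    _prepare_4d_causal_attention_mask_with_cache_position,
)

# Batch of 2, prefill of 4 tokens, target (cache) length of 6; second row is left-padded.
attention_mask = Tensor(np.array([[1, 1, 1, 1], [0, 1, 1, 1]]), mindspore.int64)
cache_position = Tensor(np.arange(4), mindspore.int64)

mask = _prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask,
    sequence_length=4,
    target_length=6,
    dtype=mindspore.float16,
    min_dtype=float(np.finfo(np.float16).min),
    cache_position=cache_position,
    batch_size=2,
)
print(mask.shape)  # (2, 1, 4, 6): additive mask, 0 where attention is allowed, min_dtype elsewhere
```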
@@ -659,8 +708,6 @@ def forward(
         w2_chunked = [(ops.squeeze(w2,dim=0)) for w2 in w2_chunked]
         for expert_idx in range(0, self.moe_num_experts):
             topk_idx, token_idx = ops.nonzero(expert_mask[expert_idx], as_tuple=True)
-            topk_idx = topk_idx.astype(mindspore.int32)
-            token_idx = token_idx.astype(mindspore.int32)
             if token_idx.shape[0] == 0:
                 continue

@@ -678,7 +725,7 @@ def forward(
                 * top_weights[token_list, topk_list, None]
             )

-            out.index_add(0, token_idx, expert_out)
+            out = out.index_add(0, token_idx.int(), expert_out)
         out = out.reshape(bsz, q_len, hidden_size)
         return out
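A small, hedged illustration of the functional `index_add` pattern used above (mindspore assumed): the result is assigned back to `out` rather than relying on in-place mutation, and the index tensor must be an integer type, hence the `.int()` cast in the diff.

```python
import numpy as np
import mindspore
from mindspore import Tensor

out = Tensor(np.zeros((4, 3), np.float32))             # accumulator over all tokens
token_idx = Tensor(np.array([1, 3]), mindspore.int32)  # rows owned by the current expert
expert_out = Tensor(np.ones((2, 3), np.float32))       # this expert's contribution

out = out.index_add(0, token_idx, expert_out)          # scatter-add along dim 0, assigned back
print(out)
```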

@@ -1123,95 +1170,40 @@ def _update_causal_mask(
         past_key_values: Cache,
         output_attentions: bool,
     ):
-        # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
-        # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
-        # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
-        # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114

-        if self.config._attn_implementation == "flash_attention_2":
-            if attention_mask is not None and 0.0 in attention_mask:
-                return attention_mask
-            return None

-        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
-        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
-        # to infer the attention mask.
-        past_seen_tokens = (
-            past_key_values.get_seq_length() if past_key_values is not None else 0
-        )
+        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
         using_static_cache = isinstance(past_key_values, StaticCache)

-        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
-        if (
-            self.config._attn_implementation == "sdpa"
-            and not using_static_cache
-            and not output_attentions
-        ):
-            if AttentionMaskConverter._ignore_causal_mask_sdpa(
-                attention_mask,
-                inputs_embeds=input_tensor,
-                past_key_values_length=past_seen_tokens,
-                is_training=self.training,
-            ):
-                return None

         dtype = input_tensor.dtype
-        min_dtype = ops.finfo(dtype).min
+        min_dtype = float(ops.finfo(dtype).min)
         sequence_length = input_tensor.shape[1]
         if using_static_cache:
             target_length = past_key_values.get_max_length()
         else:
-            if isinstance(attention_mask, mindspore.Tensor):
-                target_length = attention_mask.shape[-1]
-            else:
-                target_length = past_seen_tokens + sequence_length + 1

-        if attention_mask is not None and attention_mask.dim() == 4:
-            # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
-            if attention_mask.max() != 0:
-                raise ValueError(
-                    "Custom 4D attention mask should be passed in inverted form with max==0`"
-                )
-            causal_mask = attention_mask
-        else:
-            causal_mask = ops.full(
-                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype
-            )
-            if sequence_length != 1:
-                causal_mask = ops.triu(causal_mask, diagonal=1)
-            causal_mask *= ops.arange(target_length) > cache_position.reshape(-1, 1)
-            causal_mask = causal_mask[None, None, :, :].broadcast_to(
-                (input_tensor.shape[0], 1, -1, -1)
-            )
-            if attention_mask is not None:
-                causal_mask = ops.clone(
-                    causal_mask
-                )  # copy to contiguous memory for in-place edit
-                mask_length = attention_mask.shape[-1]
-                padding_mask = (
-                    causal_mask[:, :, :, :mask_length]
-                    + attention_mask[:, None, None, :]
-                )
-                padding_mask = padding_mask == 0
-                causal_mask[:, :, :, :mask_length] = causal_mask[
-                    :, :, :, :mask_length
-                ].masked_fill(padding_mask, min_dtype)
-        if (
-            self.config._attn_implementation == "sdpa"
-            and attention_mask is not None
-            and attention_mask.device.type == "cuda"
-            and not output_attentions
-        ):
-            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
-            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
-            # Details: https://github.com/pytorch/pytorch/issues/110213
-            causal_mask = AttentionMaskConverter._unmask_unattended(
-                causal_mask, min_dtype
+            target_length = (
+                attention_mask.shape[-1]
+                if isinstance(attention_mask, mindspore.Tensor)
+                else past_seen_tokens + sequence_length + 1
             )

+        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
+        causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
+            attention_mask,
+            sequence_length=sequence_length,
+            target_length=target_length,
+            dtype=dtype,
+            min_dtype=min_dtype,
+            cache_position=cache_position,
+            batch_size=input_tensor.shape[0],
+        )

         return causal_mask


 class DbrxForCausalLM(DbrxPreTrainedModel):
     def __init__(self, config: DbrxConfig):
         super().__init__(config)
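For clarity, a pure-Python sketch of the `target_length` selection the slimmed-down `_update_causal_mask` performs before delegating to the helper; the literal values are illustrative and the names mirror the method's locals.

```python
past_seen_tokens = 12        # e.g. past_key_values.get_seq_length()
sequence_length = 1          # one new token during decoding
using_static_cache = False
attention_mask_len = 13      # attention_mask.shape[-1] when a 2D mask is supplied, else None

if using_static_cache:
    target_length = 16       # past_key_values.get_max_length(): the full static-cache size
else:
    target_length = (
        attention_mask_len
        if attention_mask_len is not None
        else past_seen_tokens + sequence_length + 1
    )
print(target_length)         # 13: the 4D mask must cover every cached position plus the new token
```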
@@ -1388,9 +1380,8 @@ def prepare_inputs_for_generation(

         if attention_mask is not None and position_ids is None:
             # create position_ids on the fly for batch generation
-            position_ids = attention_mask.long().cumsum(-1) - 1
-            # position_ids.masked_fill_(attention_mask == 0, 1)
-            ops.masked_fill(position_ids, attention_mask == 0, value=1)
+            position_ids = attention_mask.int().cumsum(-1) - 1
+            position_ids = position_ids.masked_fill(attention_mask == 0, 1)
             if past_key_values:
                 position_ids = position_ids[:, -input_ids.shape[1] :]

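A worked NumPy example of the new `position_ids` construction above, with `np.where` standing in for `masked_fill` (illustrative only):

```python
import numpy as np

attention_mask = np.array([[0, 0, 1, 1, 1],    # left-padded sequence
                           [1, 1, 1, 1, 1]])   # full-length sequence

position_ids = attention_mask.astype(np.int32).cumsum(-1) - 1
position_ids = np.where(attention_mask == 0, 1, position_ids)  # masked_fill(mask == 0, 1)
print(position_ids)
# [[1 1 0 1 2]
#  [0 1 2 3 4]]
```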