21 changes: 20 additions & 1 deletion mindnlp/core/nn/functional.py
@@ -177,6 +177,25 @@ def binary_cross_entropy_with_logits(input, target, weight=None, reduction='mean
        return mindspore.mint.nn.functional.binary_cross_entropy_with_logits(input, target, weight, reduction, pos_weight)
    return ops.binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction)

def gumbel_softmax(logits: Tensor, tau: float = 1, hard: bool = False, eps: float = 1e-10, dim: int = -1) -> Tensor:
    if eps != 1e-10:
        warnings.warn("`eps` parameter is deprecated and has no effect.")

    uniform_samples = _get_cache_prim(ops.UniformReal)()(logits.shape)
    gumbels = -ops.log(-ops.log(uniform_samples + eps) + eps)  # ~Gumbel(0, 1)
    gumbels = (logits + gumbels) / tau  # ~Gumbel(logits, tau)
    y_soft = softmax(gumbels, dim)

    if hard:
        # Straight through.
        index = y_soft.argmax(dim)
        y_hard = one_hot(index, logits.shape[dim])
        ret = ops.stop_gradient(y_hard - y_soft) + y_soft
    else:
        # Reparametrization trick.
        ret = y_soft
    return ret

def log_softmax(input, dim=-1, dtype=None):
    out = ops.log_softmax(input, dim)
    if dtype is not None:
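
A minimal usage sketch of the new gumbel_softmax added above (the logits and temperature are invented): with hard=True the forward pass yields a one-hot sample while gradients still flow through the soft probabilities via the straight-through trick.

import mindspore
from mindnlp.core.nn import functional as F

logits = mindspore.Tensor([[1.0, 2.0, 0.5]], mindspore.float32)   # invented class scores
y = F.gumbel_softmax(logits, tau=0.5, hard=True)                  # one-hot sample in the forward pass
print(y.sum(), y.argmax(-1))                                      # sums to 1.0; index of the sampled class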
@@ -791,7 +810,7 @@ def multi_head_attention_forward(
        assert key_padding_mask.shape == (bsz, src_len), \
            f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
        key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len). \
            expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
            broadcast_to((-1, num_heads, -1, -1)).reshape(bsz * num_heads, 1, src_len)
        if attn_mask is None:
            attn_mask = key_padding_mask
        else:
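For context, MindSpore's Tensor.broadcast_to takes the whole target shape as one tuple (with -1 keeping a dimension), unlike torch's positional expand. A shape-only sketch with made-up sizes:

import numpy as np
import mindspore

bsz, num_heads, src_len = 2, 4, 5                                       # invented sizes
mask = mindspore.Tensor(np.zeros((bsz, 1, 1, src_len), dtype=np.bool_))
mask = mask.broadcast_to((-1, num_heads, -1, -1)).reshape(bsz * num_heads, 1, src_len)
print(mask.shape)                                                       # (8, 1, 5)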
9 changes: 8 additions & 1 deletion mindnlp/core/serialization.py
@@ -35,7 +35,7 @@

import numpy as np
import mindspore
from mindspore import Tensor
from mindspore import Tensor, Parameter
from mindspore.train.serialization import _exec_save, _parse_ckpt_proto, tensor_to_np_type, tensor_to_ms_type

import safetensors
@@ -756,6 +756,13 @@ def _open_zipfile_writer(name_or_buffer):
        container = _open_zipfile_writer_buffer
    return container(name_or_buffer)

def _rebuild_parameter(data, requires_grad, backward_hooks):
    param = Parameter(data, requires_grad=requires_grad)
    # NB: `backward_hooks` is accepted only for backwards compatibility; the
    # general expectation is that it is an empty OrderedDict, and it is not
    # applied here. See Note [Don't serialize hooks]
    return param

def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None):
    '''Rebuilds a tensor based on the provided parameters.

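A rough sketch of what the new _rebuild_parameter hook returns; torch-style checkpoints pickle parameters as a call of the form _rebuild_parameter(data, requires_grad, backward_hooks), so the hooks argument is accepted but unused. The data tensor below is invented.

import numpy as np
import mindspore
from mindnlp.core.serialization import _rebuild_parameter

data = mindspore.Tensor(np.ones((2, 3), dtype=np.float32))    # stands in for the rebuilt storage tensor
param = _rebuild_parameter(data, requires_grad=True, backward_hooks={})
print(type(param).__name__, param.requires_grad)              # Parameter True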
2 changes: 1 addition & 1 deletion mindnlp/peft/peft_model.py
@@ -253,7 +253,7 @@ def load_adapter(self, model_id: str, adapter_name: str, is_trainable: bool = Fa

        return load_result

    def get_nb_trainable_parameters(self) -> tuple[int, int]:
    def get_nb_trainable_parameters(self):
        r"""
        Returns the number of trainable parameters and the number of all parameters in the model.
        """
2 changes: 1 addition & 1 deletion mindnlp/transformers/cache_utils.py
@@ -990,7 +990,7 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
if len(self.self_attention_cache.key_cache) <= layer_idx:
return 0
return (ops.any(self.self_attention_cache.key_cache[layer_idx][0, 0], dim=-1)).sum().item()
return (ops.any(self.self_attention_cache.key_cache[layer_idx][0, 0].bool(), dim=-1)).sum().item()

    def reset(self):
        if hasattr(self.self_attention_cache, "reset"):
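A tiny sketch of the sequence-length heuristic patched above; casting to bool before ops.any lets the reduction count positions whose cached key is non-zero. The key slice and its values are invented.

import numpy as np
import mindspore
from mindnlp.core import ops

key_slice = mindspore.Tensor(np.array([[1.0, 0.5], [0.2, 0.0], [0.0, 0.0]], dtype=np.float32))  # (seq_len, head_dim)
valid = ops.any(key_slice.bool(), dim=-1)   # True where any key component is non-zero
print(valid.sum().item())                   # 2 cached positions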
@@ -20,7 +20,7 @@
import mindspore
from mindspore.common.initializer import Uniform, HeNormal, initializer,Normal

from mindnlp.core import nn, ops
from mindnlp.core import nn, ops, no_grad
from mindnlp.core.nn import functional as F
from mindnlp.utils import logging
from ...activations import ACT2FN
@@ -969,8 +969,9 @@ def forward(
        if labels is not None:
            # retrieve loss input_lengths from attention_mask
            labels = labels.astype(mindspore.int32)
            # if labels.max() >= self.config.vocab_size:
            #     raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
            with no_grad():
                if ops.max(labels) >= self.config.vocab_size:
                    raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
            attention_mask = (
                attention_mask if attention_mask is not None else ops.ones_like(input_values, dtype=mindspore.int64)
            )
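The guarded check above, shown standalone with an invented vocab size and invented labels; no_grad() keeps the validation itself out of gradient tracing.

import mindspore
from mindnlp.core import ops, no_grad

vocab_size = 32                                             # stands in for self.config.vocab_size
labels = mindspore.Tensor([[3, 7, 31]], mindspore.int32)    # invented CTC labels
with no_grad():
    if ops.max(labels) >= vocab_size:
        raise ValueError(f"Label values must be <= vocab_size: {vocab_size}")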
@@ -1687,9 +1687,9 @@ def gen_encoder_output_proposals(self, enc_output, padding_mask, spatial_shapes)
            proposals.append(proposal)
            _cur += height * width
        output_proposals = ops.cat(proposals, 1)
        output_proposals_valid = (
            (output_proposals > 0.01).int() & (output_proposals < 0.99).int()
        ).all(-1, keep_dims=True)
        output_proposals_valid = ops.all(
            ((output_proposals > 0.01).int() & (output_proposals < 0.99).int()).bool(), -1, keepdim=True
        )
        output_proposals = ops.log(
            output_proposals / (1 - output_proposals)
        )  # inverse sigmoid
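
A compact illustration of the rewritten validity mask (the proposal coordinates are invented): a query is kept only if every coordinate lies strictly inside (0.01, 0.99) before the inverse-sigmoid transform.

import numpy as np
import mindspore
from mindnlp.core import ops

output_proposals = mindspore.Tensor(np.array([[[0.50, 0.90], [0.005, 0.50]]], dtype=np.float32))  # (batch, queries, coords)
output_proposals_valid = ops.all(
    ((output_proposals > 0.01).int() & (output_proposals < 0.99).int()).bool(), -1, keepdim=True
)
print(output_proposals_valid)   # [[[True], [False]]]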
@@ -2291,8 +2291,9 @@ def loss_labels(self, outputs, targets, indices, num_boxes):
            source_logits.shape[2] + 1,
            dtype=source_logits.dtype,
        )
        target_classes = target_classes.unsqueeze(-1)
        target_classes_onehot = ops.scatter(
            target_classes_onehot, 2, target_classes.unsqueeze(-1), ops.ones_like(target_classes_onehot)
            target_classes_onehot, 2, target_classes, ops.ones_like(target_classes, dtype=target_classes_onehot.dtype)
        )
        target_classes_onehot = target_classes_onehot[:, :, :-1]
        loss_ce = (
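A small worked example of the adjusted one-hot scatter; the batch size, query count, and class ids are invented. The scatter source now matches the index shape, with ones cast to the one-hot tensor's dtype.

import numpy as np
import mindspore
from mindnlp.core import ops

target_classes = mindspore.Tensor(np.array([[2, 0]], dtype=np.int64))              # (batch, num_queries)
target_classes_onehot = mindspore.Tensor(np.zeros((1, 2, 4), dtype=np.float32))    # num_labels + 1 = 4
target_classes = target_classes.unsqueeze(-1)
target_classes_onehot = ops.scatter(
    target_classes_onehot, 2, target_classes, ops.ones_like(target_classes, dtype=target_classes_onehot.dtype)
)
target_classes_onehot = target_classes_onehot[:, :, :-1]                            # drop the extra "no object" slot
print(target_classes_onehot)
# [[[0. 0. 1.]
#   [1. 0. 0.]]]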
14 changes: 2 additions & 12 deletions mindnlp/transformers/models/wav2vec2/__init__.py
@@ -15,21 +15,11 @@
''' Wav2Vec2 Model '''

from . import configuration_wav2vec2, feature_extraction_wav2vec2, processing_wav2vec2, tokenization_wav2vec2, modeling_wav2vec2
from .configuration_wav2vec2 import WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP, Wav2Vec2Config
from .configuration_wav2vec2 import *
from .feature_extraction_wav2vec2 import Wav2Vec2FeatureExtractor
from .processing_wav2vec2 import Wav2Vec2Processor
from .tokenization_wav2vec2 import Wav2Vec2CTCTokenizer, Wav2Vec2Tokenizer
from .modeling_wav2vec2 import (
    WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
    Wav2Vec2ForAudioFrameClassification,
    Wav2Vec2ForCTC,
    Wav2Vec2ForMaskedLM,
    Wav2Vec2ForPreTraining,
    Wav2Vec2ForSequenceClassification,
    Wav2Vec2ForXVector,
    Wav2Vec2Model,
    Wav2Vec2PreTrainedModel,
)
from .modeling_wav2vec2 import *

__all__ = []
__all__.extend(configuration_wav2vec2.__all__)