Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
10 changes: 10 additions & 0 deletions mindnlp/utils/safetensors_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,16 @@ def get_tensor(self, name):
def get_slice(self, name):
return self.tensors[name]

def offset_keys(self):
    """
    Return the names of the tensors in the file.

    NOTE(review): the method name and the upstream safetensors API imply the
    keys are ordered by data offset, but this implementation simply delegates
    to ``self.keys()``, so the ordering is whatever ``keys()`` yields —
    confirm against the upstream contract if offset order matters to callers.

    Returns:
        (`List[str]`):
            The names of the tensors contained in that file.
    """
    return self.keys()


def safe_save_file(tensor_dict, filename, metadata=None):
"""
Expand Down
1 change: 1 addition & 0 deletions mindtorch/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
ON_ORANGE_PI = '310b' in SOC
DEFAULT_DTYPE = mindspore.float32
# True when the installed MindSpore is >= 2.7.  Compare (major, minor) as
# integers: the previous string comparison ('.'.join(...) >= '2.7') misorders
# multi-digit minors, e.g. '2.10' < '2.7' lexicographically.
MS27 = tuple(int(v) for v in mindspore.__version__.split('.')[:2]) >= (2, 7)
# 1 (default): attention masks use the "valid position" convention and are
# inverted before being handed to the NPU flash-attention path; any other
# value passes the mask through as a bool tensor.  Overridable via the
# FLASH_ATTN_MASK_VALID environment variable.
FLASH_ATTN_MASK_VALID = int(os.environ.get('FLASH_ATTN_MASK_VALID', 1))

# OP backend select
USE_PYBOOST = True
Expand Down
9 changes: 6 additions & 3 deletions mindtorch/nn/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from mindtorch._C import default_generator
from mindtorch.nn.modules.utils import _pair

from ..configs import ON_A2, ON_A1
from ..configs import ON_A2, ON_A1, FLASH_ATTN_MASK_VALID

generator_step_ = 12

Expand Down Expand Up @@ -1162,9 +1162,12 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.
scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale


if query.device.type == 'npu' and ON_A2:
if query.dtype != mindtorch.float32 and query.device.type == 'npu' and ON_A2:
if attn_mask is not None and not is_causal:
attn_mask = ~attn_mask
if FLASH_ATTN_MASK_VALID == 1:
attn_mask = ~attn_mask
else:
attn_mask = attn_mask.bool()

sparse_mode = 0

Expand Down
Loading