fix diffusers unet ut #2081

Merged · 1 commit · Jul 23, 2025
3 changes: 3 additions & 0 deletions mindnlp/core/_tensor.py
@@ -198,6 +198,9 @@ def _expand(self, *size):
 Tensor.expand = _expand
 StubTensor.expand = _expand
 
+Tensor.broadcast_to = ops.broadcast_to
+StubTensor.broadcast_to = ops.broadcast_to
+
 def clone(self, *args, **kwargs):
     return self.copy()

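With the alias in place, the method form delegates to the patched functional form. A minimal sketch of the intended call pattern, assuming that importing mindnlp.core applies these monkey-patches to mindspore.Tensor:

import mindspore
from mindnlp.core import ops  # assumed entry point; importing it applies the patches

x = mindspore.Tensor([[1.0], [2.0]])   # shape (2, 1)
y = x.broadcast_to((2, 3))             # method form added by this patch
z = ops.broadcast_to(x, (2, 3))        # equivalent functional form
assert y.shape == z.shape == (2, 3)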
7 changes: 7 additions & 0 deletions mindnlp/core/nn/functional.py
@@ -341,6 +341,13 @@ def pad(input, pad, mode='constant', value=None):
     if input.dtype == mindspore.bool_:
         input = input.to(mindspore.int32)
         return ops.pad(input, new_pad, mode, value).to(mindspore.bool_)
+    if input.ndim > 5 and mode == 'constant':
+        paddings = ()
+        for i in range(0, len(new_pad), 2):
+            paddings += (new_pad[i: i+2],)
+
+        paddings = ((0, 0),) * (input.ndim - len(paddings)) + tuple(reversed(paddings))
+        return _get_cache_prim(ops.Pad)(paddings)(input)
     return ops.pad(input, new_pad, mode, value)
 
 def nll_loss(input, target, weight=None, ignore_index=-100, reduction='mean'):
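For context, a worked sketch (not part of the diff) of how the new >5-D constant-pad branch assembles the `paddings` argument for `ops.Pad`; `new_pad` follows the torch convention of (before, after) pairs starting from the last dimension:

# Hypothetical inputs chosen for illustration only.
new_pad = (1, 1, 2, 2)   # pad last dim by (1, 1), second-to-last by (2, 2)
ndim = 6                 # a 6-D input, beyond what ops.pad handles here

pairs = tuple(tuple(new_pad[i:i + 2]) for i in range(0, len(new_pad), 2))
# Prepend (0, 0) for each leading unpadded dim, then reverse the pairs so
# they run from the first dimension to the last, as ops.Pad expects.
paddings = ((0, 0),) * (ndim - len(pairs)) + tuple(reversed(pairs))
assert paddings == ((0, 0), (0, 0), (0, 0), (0, 0), (2, 2), (1, 1))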
18 changes: 18 additions & 0 deletions mindnlp/core/nn/modules/module.py
@@ -1114,6 +1114,24 @@ def __delattr__(self, name):
         else:
             super().__delattr__(name)
 
+    def _register_state_dict_hook(self, hook):
+        r"""Register a post-hook for the :meth:`~torch.nn.Module.state_dict` method.
+
+        It should have the following signature::
+            hook(module, state_dict, prefix, local_metadata) -> None or state_dict
+
+        The registered hooks can modify the ``state_dict`` inplace or return a new one.
+        If a new ``state_dict`` is returned, it will only be respected if it is the root
+        module that :meth:`~nn.Module.state_dict` is called from.
+        """
+        if getattr(hook, "_from_public_api", False):
+            raise RuntimeError(
+                "Cannot register the same function as the state dict post hook that was "
+                "previously registered via register_state_dict_post_hook"
+            )
+        handle = RemovableHandle(self._state_dict_hooks)
+        self._state_dict_hooks[handle.id] = hook
+        return handle
 
     def extra_repr(self) -> str:
         r"""Set the extra representation of the module.
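A minimal sketch of the hook contract this private method accepts (the module class and key prefix below are illustrative, and the import path is assumed):

from mindnlp.core import nn  # assumed import path

def strip_prefix_hook(module, state_dict, prefix, local_metadata):
    # Mutate in place: drop a hypothetical "backbone." prefix from keys.
    for key in list(state_dict.keys()):
        if key.startswith("backbone."):
            state_dict[key.replace("backbone.", "", 1)] = state_dict.pop(key)
    # Returning None keeps the mutated dict; a returned dict is only
    # respected when the hook fires on the root module.

model = nn.Linear(4, 4)
handle = model._register_state_dict_hook(strip_prefix_hook)
state = model.state_dict()   # hook runs after the dict is assembled
handle.remove()              # the returned RemovableHandle unregisters it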
4 changes: 3 additions & 1 deletion mindnlp/core/ops/other.py
@@ -63,7 +63,9 @@ def manual_expand(tensor, shape):
 has_broadcast_to = hasattr(mindspore.mint, "broadcast_to")
 
 
-def broadcast_to(input, shape):
+def broadcast_to(input, *shape):
+    if isinstance(shape[0], tuple):
+        shape = shape[0]
     if ON_ORANGE_PI and not use_pyboost():
         # return input.expand(mindspore.tensor(shape))
         return manual_expand(input, shape)
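The widened signature lets `broadcast_to` accept either a single shape tuple or unpacked dimensions, matching both torch-style call forms. A brief sketch, assuming the same import path as above:

import mindspore
from mindnlp.core import ops  # assumed import path

x = mindspore.Tensor([[1.0], [2.0]])   # shape (2, 1)
a = ops.broadcast_to(x, (2, 3))        # tuple form: shape == ((2, 3),), unwrapped
b = ops.broadcast_to(x, 2, 3)          # vararg form: shape == (2, 3)
assert a.shape == b.shape == (2, 3)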