
Commit 188c0b4

fix diffusers pipelines d class ut (#2085)
1 parent c1b22f8

File tree

5 files changed: +198 −12 lines


mindnlp/core/_tensor.py

Lines changed: 3 additions & 0 deletions
@@ -678,6 +678,9 @@ def __contains__(self, item):
 Tensor.roll = ops.roll
 StubTensor.roll = ops.roll
 
+Tensor.bernoulli_ = ops.inplace_bernoulli
+StubTensor.bernoulli_ = ops.inplace_bernoulli
+
 
 def _rebuild_from_type_v2(func, new_type, args, state):
     ret = func(*args)
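With the binding in place, mindnlp tensors gain a torch-style in-place Bernoulli method. A minimal usage sketch, assuming a torch-compatible `core.ones` constructor (the constructor is an assumption; only the `bernoulli_` binding comes from this commit):

from mindnlp import core

mask = core.ones(2, 3)    # assumed constructor, not part of this diff
mask.bernoulli_(p=0.3)    # overwrite in place with 0/1 samples at probability p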

mindnlp/core/nn/functional.py

Lines changed: 13 additions & 6 deletions
@@ -1,5 +1,6 @@
 """nn functional"""
 import math
+import numbers
 import warnings
 from typing import Optional, Tuple, List
 import numpy as np
@@ -547,7 +548,7 @@ def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners
                 "Please provide input tensor in (N, C, d1, d2, ...,dK) format and "
                 "output size in (o1, o2, ...,oK) format."
             )
-            output_size = size
+            output_size = [s.item() if not isinstance(s, numbers.Number) else s for s in size]
         else:
             output_size = [size for _ in range(dim)]
     elif scale_factor is not None:
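The rewritten line matters because diffusers pipelines may pass `size` as a sequence of 0-d tensors rather than plain ints, while downstream resize ops expect Python numbers. A small sketch of the normalization, using a `mindspore.Tensor` scalar as the non-number case:

import numbers
import mindspore

size = (mindspore.Tensor(32), 64)   # mixed 0-d tensor / int, as a pipeline may pass
output_size = [s.item() if not isinstance(s, numbers.Number) else s for s in size]
print(output_size)                  # [32, 64], plain ints either way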
@@ -637,10 +638,10 @@ def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners
         )
     if input.dim() == 4 and mode == "bilinear":
         assert align_corners is not None
-        if antialias:
-            return torch._C._nn._upsample_bilinear2d_aa(
-                input, output_size, align_corners, scale_factors
-            )
+        # if antialias:
+        #     return torch._C._nn._upsample_bilinear2d_aa(
+        #         input, output_size, align_corners, scale_factors
+        #     )
         return upsample_bilinear2d_op(
             input, output_size, scale_factors, align_corners
         )
@@ -867,7 +868,13 @@ def conv_transpose3d(input, weight, bias=None, stride=1, padding=0, output_padding
 
 def max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False):
     if use_pyboost():
-        return mint.nn.functional.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode=ceil_mode, return_indices=return_indices)
+        input_ndim = input.ndim
+        if input_ndim == 3:
+            input = input.unsqueeze(1)
+        out = mint.nn.functional.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode=ceil_mode, return_indices=return_indices)
+        if input_ndim == 3:
+            out = out.squeeze(1)
+        return out
     return ops.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode=ceil_mode, return_indices=return_indices)
 
 def max_pool1d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False):
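The fix wraps the pyboost kernel, which appears to require batched 4-D input, so that an unbatched (C, H, W) tensor still pools: a size-1 dim is inserted before the call and squeezed back out afterwards. A hedged sketch of the new behavior, assuming a torch-compatible `core.rand` constructor:

from mindnlp import core
from mindnlp.core.nn import functional as F

x = core.rand(3, 8, 8)              # (C, H, W), no batch dim; rand is assumed
y = F.max_pool2d(x, kernel_size=2)
print(y.shape)                      # (3, 4, 4): the inserted dim is squeezed away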

mindnlp/core/ops/inplace.py

Lines changed: 6 additions & 1 deletion
@@ -181,6 +181,10 @@ def inplace_sub(self, other):
     self.data = core.sub(self, other)
     return self
 
+def inplace_bernoulli(self, p=0.5, *, generator=None):
+    self.data = core.bernoulli(self, generator=generator, p=p)
+    return self
+
 __all__ = [
     'inplace_copy',
     'inplace_zero',
@@ -202,5 +206,6 @@ def inplace_sub(self, other):
     'inplace_mul',
     'inplace_neg',
     'inplace_exp',
-    'inplace_sub'
+    'inplace_sub',
+    'inplace_bernoulli'
 ]
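`inplace_bernoulli` follows the same pattern as the other helpers in this file: compute out of place, then rebind `self.data`, so the caller's tensor object is mutated and returned. A short sketch, assuming a torch-compatible `core.ones` constructor:

from mindnlp import core
from mindnlp.core import ops

x = core.ones(4)                      # assumed constructor
y = ops.inplace_bernoulli(x, p=0.8)   # x now holds 0/1 samples
assert y is x                         # the same object comes back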

mindnlp/core/ops/random.py

Lines changed: 3 additions & 2 deletions
@@ -12,11 +12,12 @@
 
 # bernoulli
 has_bernoulli = hasattr(mindspore.mint, 'bernoulli')
-def bernoulli(input, *, generator=None, out=None):
+def bernoulli(input, *, generator=None, out=None, **kwargs):
+    p = kwargs.pop('p', 0.5)
     if use_pyboost() and has_bernoulli:
         return call_ms_func(mindspore.mint.bernoulli, input, generator=generator, out=out)
     random_numbers = rand(*input.shape, dtype=mindspore.float32)
-    samples = random_numbers < 0.5
+    samples = random_numbers < p
     samples = samples.int()
     if out is None:
         return samples
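Note the split the `p` keyword introduces: on the pyboost path `mindspore.mint.bernoulli` presumably mirrors torch.bernoulli and treats `input` element-wise as probabilities (so `p` is popped but unused there), while the fallback uses `input` only for its shape and thresholds fresh uniform noise at the scalar `p`. A sketch of the fallback semantics, assuming a torch-compatible `core.ones` constructor:

from mindnlp import core
from mindnlp.core import ops

x = core.ones(2, 5)            # shape donor on the fallback path; assumed constructor
s = ops.bernoulli(x, p=0.2)    # roughly 20% ones, drawn independently of x's values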

mindnlp/core/ops/reduction.py

Lines changed: 173 additions & 3 deletions
@@ -1,11 +1,13 @@
 """reduction op"""
+import numbers
 from collections import namedtuple
 import mindspore
 from mindspore import ops
 from mindspore.ops._primitive_cache import _get_cache_prim
 from ..configs import use_pyboost, DEVICE_TARGET
 
 from ._inner import call_ms_func
+from mindnlp import core
 
 max_out = namedtuple('max_out', ['values', 'indices'])
 min_out = namedtuple('min_out', ['values', 'indices'])
@@ -154,12 +156,180 @@ def prod(input, dim=None, keepdim=False, *, dtype=None):
     return ops.prod(input, dim, keepdim).to(dtype)
 
 # quantile
-def quantile(input, q, dim=None, keepdim=False, *, interpolation='linear'):
-    return ops.quantile(input, q, dim, keepdim)
+def quantile_output_shape(
+    original_dim,
+    input_tensor,
+    q,
+    keepdim,
+    wrapped_dim
+):
+    """
+    Compute the output shape of the quantile functions.
+
+    Args:
+        original_dim: the original dim argument (None means flatten)
+        input_tensor: the input tensor
+        q: the quantile tensor
+        keepdim: whether to keep the reduced dimension
+        wrapped_dim: the dimension index after wrapping negatives
+    """
+    # Output shape: the size of q prepended to the reduced input shape
+    out_shape = []
+
+    if original_dim is not None and input_tensor.dim() > 0:
+        # Preserve the original dimension structure
+        out_shape = list(input_tensor.shape)
+        if keepdim:
+            out_shape[wrapped_dim] = 1
+        else:
+            del out_shape[wrapped_dim]
+    elif keepdim:
+        # Flattened input but keepdim requested: use an all-ones shape
+        out_shape = [1] * input_tensor.dim()
+
+    if q.dim() > 0:
+        # Prepend the quantile dimension
+        out_shape.insert(0, q.numel())
+
+    return out_shape
+
+
+def quantile(
+    input_tensor,
+    q,
+    dim = None,
+    keepdim: bool = False,
+    interpolation: str = 'linear',
+    ignore_nan: bool = False
+):
+    """
+    Full implementation of the PyTorch quantile function.
+
+    Args:
+        input_tensor: the input data
+        q: the quantile(s), in [0, 1]
+        dim: the dimension to reduce
+        keepdim: whether to keep the reduced dimension
+        interpolation: interpolation mode ('linear', 'lower', 'higher', 'nearest', 'midpoint')
+        ignore_nan: whether to ignore NaN values
+
+    Returns:
+        The computed quantiles.
+    """
+    if isinstance(q, numbers.Number):
+        q = core.tensor(q, dtype=input_tensor.dtype)
+    # ===== 1. Input validation =====
+    device = input_tensor.device
+    dtype = input_tensor.dtype
+
+    # Validate the quantile range
+    if device.type == 'cpu':
+        if not core.all((q >= 0) & (q <= 1)):
+            raise ValueError("quantile() q values must be in the range [0, 1]")
+
+    # ===== 2. Dimension handling =====
+    wrapped_dim = dim if dim is not None else 0
+    original_dim = dim
+
+    if dim is not None:
+        # Validate the dimension
+        if dim < 0:
+            dim = input_tensor.dim() + dim
+        if dim < 0 or dim >= input_tensor.dim():
+            raise ValueError(f"Dimension out of range (expected to be in range [{-input_tensor.dim()}, {input_tensor.dim()-1}])")
+        wrapped_dim = dim
+
+    # Compute the output shape
+    out_shape = quantile_output_shape(original_dim, input_tensor, q, keepdim, wrapped_dim)
+
+    # ===== 3. Data preprocessing =====
+    # Handle a scalar quantile
+    q_scalar = q.dim() == 0
+    q = q.reshape(-1)  # make sure q is 1-D
+
+    # Flatten or reorder dimensions
+    if dim is None:
+        # Flatten the whole tensor
+        sorted_x, _ = input_tensor.flatten().sort()
+    elif wrapped_dim == input_tensor.dim() - 1:
+        # The target dim is already last: sort directly
+        sorted_x, _ = input_tensor.sort(dim=wrapped_dim)
+    else:
+        # Move the target dim to the end, then sort
+        transposed = input_tensor.transpose(wrapped_dim, -1).unsqueeze(-1)
+        sorted_x, _ = transposed.sort(dim=-2)
+        sorted_x = sorted_x.squeeze(-1)
+
+    # ===== 4. Core quantile computation =====
+    n = sorted_x.shape[-1]
+
+    # Handle empty input
+    if n == 0:
+        result = core.full(out_shape, float('nan'), device=device, dtype=dtype)
+        return result
+
+    # Compute rank positions (with NaN handling)
+    if ignore_nan:
+        # Count the non-NaN values
+        non_nan_count = (~sorted_x.isnan()).sum(dim=-1, keepdim=True)
+        ranks = q * (non_nan_count - 1)
+        ranks = core.clamp(ranks, min=0)  # guard against negative indices
+    else:
+        last_index = n - 1
+        # Broadcast the NaN flags
+        nan_mask = sorted_x.isnan().any(dim=-1, keepdim=True)
+        # Expand q and nan_mask to the same shape
+        expanded_q = q.view(1, -1).expand(*sorted_x.shape[:-1], q.numel())
+        nan_mask = nan_mask.expand_as(expanded_q)
+        # Compute the base ranks
+        ranks = expanded_q * last_index
+        # Rows containing NaN use the last index
+        ranks = core.where(nan_mask, core.tensor(last_index, device=device), ranks)
+
+    # Adjust the ranks according to the interpolation mode
+    if interpolation == 'lower':
+        ranks = core.floor(ranks)
+    elif interpolation == 'higher':
+        ranks = core.ceil(ranks)
+    elif interpolation == 'nearest':
+        ranks = core.round(ranks)
+
+    # Keep the ranks within the valid range
+    ranks = core.clamp(ranks, 0, n - 1)
+
+    # Lower-bound indices and values
+    ranks_below = ranks.to(core.int64)
+    values_below = sorted_x.gather(-1, ranks_below)
+
+    # ===== 5. Interpolation =====
+    if interpolation in ['linear', 'midpoint']:
+        # Interpolation weights
+        weights = core.full_like(ranks, 0.5) if interpolation == 'midpoint' else ranks - ranks_below
+
+        # Upper-bound values
+        ranks_above = core.ceil(ranks).to(core.int64)
+        values_above = sorted_x.gather(-1, ranks_above)
+
+        # Linear interpolation: result = (1 - weight)*below + weight*above
+        values_below = values_below.lerp(values_above, weights)
+
+    # ===== 6. Shape adjustment =====
+    if q_scalar:
+        # Scalar quantile: drop the quantile dimension
+        values_below = values_below.squeeze(-1)
+    else:
+        # Multiple quantiles: move the quantile dim to the front
+        values_below = values_below.movedim(-1, 0)
+
+    # Restore the expected output shape
+    if values_below.shape != tuple(out_shape):
+        values_below = values_below.reshape(out_shape)
+
+    return values_below
 
 # nanquantile
 def nanquantile(input, q, dim=None, keepdim=False, *, interpolation='linear'):
-    return ops.quantile(input, q, dim, keepdim)
+    return ops.nanquantile(input, q, dim, keepdim)
 
 # std
 has_std = hasattr(mindspore.mint, 'std')
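A quick check of the new quantile's shape contract, which mirrors torch.quantile: a scalar q is squeezed away, while a 1-D q lands as a leading dimension. Sketch, assuming the same torch-compatible `core.tensor` constructor the implementation itself uses:

from mindnlp import core
from mindnlp.core.ops.reduction import quantile

x = core.tensor([[1.0, 2.0, 3.0],
                 [4.0, 5.0, 6.0]])
print(quantile(x, 0.5, dim=1))                               # [2.0, 5.0]: row medians, shape (2,)
print(quantile(x, core.tensor([0.25, 0.75]), dim=1).shape)   # (2, 2): the q dim comes first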
