 """reduction op"""
+import numbers
 from collections import namedtuple
 import mindspore
 from mindspore import ops
 from mindspore.ops._primitive_cache import _get_cache_prim
 from ..configs import use_pyboost, DEVICE_TARGET

 from ._inner import call_ms_func
+from mindnlp import core

 max_out = namedtuple('max_out', ['values', 'indices'])
 min_out = namedtuple('min_out', ['values', 'indices'])
@@ -154,12 +156,180 @@ def prod(input, dim=None, keepdim=False, *, dtype=None):
     return ops.prod(input, dim, keepdim).to(dtype)

 # quantile
-def quantile(input, q, dim=None, keepdim=False, *, interpolation='linear'):
-    return ops.quantile(input, q, dim, keepdim)
+def quantile_output_shape(
+    original_dim,
+    input_tensor,
+    q,
+    keepdim,
+    wrapped_dim
+):
| 166 | + """ |
| 167 | + 计算分位数函数的输出形状 |
| 168 | + |
| 169 | + 参数: |
| 170 | + original_dim: 原始维度(None表示展平) |
| 171 | + input_tensor: 输入张量 |
| 172 | + q: 分位数张量 |
| 173 | + keepdim: 是否保留维度 |
| 174 | + wrapped_dim: 处理后的维度索引 |
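+
+    Example (illustrative): input_tensor of shape (4, 5) with q of shape (3,)
+    and dim=1 gives [3, 4], or [3, 4, 1] with keepdim=True.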
| 175 | + """ |
| 176 | + # 计算输出形状: q大小 + 缩减维度后的大小 |
| 177 | + out_shape = [] |
| 178 | + |
| 179 | + if original_dim is not None and input_tensor.dim() > 0: |
| 180 | + # 保留原始维度结构 |
| 181 | + out_shape = list(input_tensor.shape) |
| 182 | + if keepdim: |
| 183 | + out_shape[wrapped_dim] = 1 |
| 184 | + else: |
| 185 | + del out_shape[wrapped_dim] |
| 186 | + elif keepdim: |
| 187 | + # 当展平但需保留维度时创建全1形状 |
| 188 | + out_shape = [1] * input_tensor.dim() |
| 189 | + |
| 190 | + if q.dim() > 0: |
| 191 | + # 添加分位数维度到最前面 |
| 192 | + out_shape.insert(0, q.numel()) |
| 193 | + |
| 194 | + return out_shape |
| 195 | + |
| 196 | + |
+def quantile(
+    input,
+    q,
+    dim=None,
+    keepdim: bool = False,
+    *,
+    interpolation: str = 'linear',
+    ignore_nan: bool = False
+):
| 205 | + """ |
| 206 | + PyTorch分位数函数的完整实现 |
| 207 | + |
| 208 | + 参数: |
| 209 | + input_tensor: 输入数据 |
| 210 | + q: 分位数(0-1之间) |
| 211 | + dim: 计算维度 |
| 212 | + keepdim: 是否保留维度 |
| 213 | + interpolation: 插值模式 ('linear', 'lower', 'higher', 'nearest', 'midpoint') |
| 214 | + ignore_nan: 是否忽略NaN值 |
| 215 | + |
| 216 | + 返回: |
| 217 | + 计算得到的分位数 |
| 218 | + """ |
+    if isinstance(q, numbers.Number):
+        q = core.tensor(q, dtype=input.dtype)
+    # ===== 1. Input validation =====
+    device = input.device
+    dtype = input.dtype
+
+    # Validate the quantile range (checked eagerly on CPU only)
+    if device.type == 'cpu':
+        if not core.all((q >= 0) & (q <= 1)):
+            raise ValueError("quantile() q values must be in the range [0, 1]")
+
+    # ===== 2. Dimension handling =====
+    wrapped_dim = dim if dim is not None else 0
+    original_dim = dim
+
+    if dim is not None:
+        # Normalize and validate the dimension
+        if dim < 0:
+            dim = input.dim() + dim
+        if dim < 0 or dim >= input.dim():
+            raise ValueError(f"Dimension out of range (expected to be in range [{-input.dim()}, {input.dim() - 1}])")
+        wrapped_dim = dim
+
+    # Compute the output shape
+    out_shape = quantile_output_shape(original_dim, input, q, keepdim, wrapped_dim)
+
+    # ===== 3. Preprocessing =====
+    # Handle scalar quantiles
+    q_scalar = q.dim() == 0
+    q = q.reshape(-1)  # make sure q is 1-D
+
+    # Flatten, or move the target dimension into position for sorting
+    if dim is None:
+        # Flatten the whole tensor
+        sorted_x, _ = input.flatten().sort()
+    elif wrapped_dim == input.dim() - 1:
+        # The target dimension is already last; sort directly
+        sorted_x, _ = input.sort(dim=wrapped_dim)
+    else:
+        # Unsqueeze first, then swap the target dimension to the end: the size-1
+        # slot left behind keeps the batch dimensions in their original order
+        # for the final reshape (swapping before unsqueezing would scramble them)
+        transposed = input.unsqueeze(-1).transpose(wrapped_dim, -1)
+        sorted_x, _ = transposed.sort(dim=-1)
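+        # e.g. input of shape (A, B, C) with dim=0 becomes (1, B, C, A); sorting
+        # the last dimension orders the A-values while (B, C) stay in place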
+
+    # ===== 4. Core quantile computation =====
+    n = sorted_x.shape[-1]
+
+    # Handle empty input
+    if n == 0:
+        result = core.full(out_shape, float('nan'), device=device, dtype=dtype)
+        return result
+
+    # Compute rank positions (with NaN handling)
+    if ignore_nan:
+        # Count the non-NaN values in each slice
+        non_nan_count = (~sorted_x.isnan()).sum(dim=-1, keepdim=True)
+        ranks = q * (non_nan_count - 1)
+        ranks = core.clamp(ranks, min=0)  # guard against negative indices
+    else:
+        last_index = n - 1
+        # Flag slices that contain NaN
+        nan_mask = sorted_x.isnan().any(dim=-1, keepdim=True)
+        # Broadcast q and nan_mask to a common shape; expanding q directly
+        # (rather than via view(1, -1)) also covers the flattened 1-D case
+        expanded_q = q.expand(*sorted_x.shape[:-1], q.numel())
+        nan_mask = nan_mask.expand_as(expanded_q)
+        # Base ranks
+        ranks = expanded_q * last_index
+        # Slices containing NaN use the last index
+        ranks = core.where(nan_mask, core.tensor(last_index, device=device), ranks)
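+        # (assumes sort() places NaNs last, as torch.sort does, so such slices
+        # pick up a NaN value, matching torch.quantile's NaN propagation)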
+
+    # Adjust ranks for the interpolation mode
+    if interpolation == 'lower':
+        ranks = core.floor(ranks)
+    elif interpolation == 'higher':
+        ranks = core.ceil(ranks)
+    elif interpolation == 'nearest':
+        ranks = core.round(ranks)
+    elif interpolation not in ('linear', 'midpoint'):
+        raise ValueError(f"quantile() interpolation must be one of 'linear', 'lower', 'higher', 'midpoint' or 'nearest', but got {interpolation!r}")
+
+    # Keep ranks within valid bounds
+    ranks = core.clamp(ranks, 0, n - 1)
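+    # e.g. with n = 5 and q = 0.4 the rank is 1.6: 'lower' -> 1, 'higher' -> 2,
+    # 'nearest' -> 2, and 'linear' blends sorted_x[1] and sorted_x[2] with weight 0.6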
+
+    # Lower-bound indices and values
+    ranks_below = ranks.to(core.int64)
+    values_below = sorted_x.gather(-1, ranks_below)
+
+    # ===== 5. Interpolation =====
+    if interpolation in ['linear', 'midpoint']:
+        # Interpolation weights
+        weights = core.full_like(ranks, 0.5) if interpolation == 'midpoint' else ranks - ranks_below
+
+        # Upper-bound values
+        ranks_above = core.ceil(ranks).to(core.int64)
+        values_above = sorted_x.gather(-1, ranks_above)
+
+        # Linear interpolation: result = (1 - weight) * below + weight * above
+        values_below = values_below.lerp(values_above, weights)
+
+    # ===== 6. Shape adjustment =====
+    if q_scalar:
+        # Scalar quantile: drop the quantile dimension
+        values_below = values_below.squeeze(-1)
+    else:
+        # Multiple quantiles: move the quantile dimension to the front
+        values_below = values_below.movedim(-1, 0)
+
+    # Restore the expected output shape
+    if values_below.shape != tuple(out_shape):
+        values_below = values_below.reshape(out_shape)
+
+    return values_below
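+# Usage sketch (illustrative; mirrors torch.quantile):
+#   x = core.tensor([[1., 2., 3., 4.]])
+#   quantile(x, 0.5, dim=1)                        # -> tensor([2.5])
+#   quantile(x, core.tensor([0.25, 0.75]), dim=1)  # -> tensor([[1.75], [3.25]])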

 # nanquantile
 def nanquantile(input, q, dim=None, keepdim=False, *, interpolation='linear'):
-    return ops.quantile(input, q, dim, keepdim)
+    # Route through the quantile implementation above so that `interpolation`
+    # is honored and NaN values are skipped
+    return quantile(input, q, dim, keepdim, interpolation=interpolation, ignore_nan=True)
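+# e.g. nanquantile(core.tensor([1., 2., float('nan')]), 0.5) -> tensor(1.5),
+# where quantile() on the same input would return NaN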

 # std
 has_std = hasattr(mindspore.mint, 'std')