diff --git a/mindnlp/__init__.py b/mindnlp/__init__.py
index 3cef7d9a8..d085a25e8 100644
--- a/mindnlp/__init__.py
+++ b/mindnlp/__init__.py
@@ -38,7 +38,7 @@
 mindspore.set_device(os.environ.get('DEVICE_TARGET'))
 
 # for different ascend devices
-if platform.system().lower() == 'linux':
+if platform.system().lower() == 'linux' and mindspore.get_context('device_target') == 'Ascend':
     SOC = MSContext.get_instance().get_ascend_soc_version()
     # enable vmm since only vmm can release device memory when del tensor.
     if SOC != 'ascend310b':
diff --git a/mindnlp/core/__init__.py b/mindnlp/core/__init__.py
index e1e3fe26f..078935bae 100644
--- a/mindnlp/core/__init__.py
+++ b/mindnlp/core/__init__.py
@@ -36,6 +36,7 @@
 preserve_format = None
 legacy_contiguous_format = None
 channels_last_3d = None
+channels_last = None
 memory_format = None
 
 inf = float("inf")
diff --git a/mindnlp/core/_apis/cpu.py b/mindnlp/core/_apis/cpu.py
index 610188b67..290dae33b 100644
--- a/mindnlp/core/_apis/cpu.py
+++ b/mindnlp/core/_apis/cpu.py
@@ -1221,3 +1221,12 @@ def logsumexp(input, dim, keepdim=False):
 
 def bernoulli(input, generator):
     return legacy.bernoulli(input, seed, offset)
+
+def right_shift(input, other):
+    return legacy.right_shift(input, other)
+
+def histc(input, bins=100, min=0, max=0):
+    return legacy.histogram(input, bins, float(min), float(max))
+
+def search_sorted(sorted_sequence, values, sorter, dtype, right):
+    return legacy.search_sorted(sorted_sequence, values, sorter, dtype, right)
\ No newline at end of file
diff --git a/mindnlp/core/_apis/gpu.py b/mindnlp/core/_apis/gpu.py
index 914897e6d..d38aa3478 100644
--- a/mindnlp/core/_apis/gpu.py
+++ b/mindnlp/core/_apis/gpu.py
@@ -4,7 +4,7 @@ import mindspore
 from mindspore._c_expression import _empty_instance
 
 from mindnlp import core
-from .._op_prim.cpu import legacy
+from .._op_prim.gpu import legacy
 
 try:
     from mindspore._c_expression import TensorPy as Tensor_
@@ -34,6 +34,8 @@ def fill_scalar(size, fill_value, dtype):
     return legacy.cast(legacy.fill_v2(size, mindspore.Tensor(fill_value)), dtype)
 
 def fill_tensor(size, fill_value, dtype):
+    if dtype is None:
+        return legacy.fill_v2(size, mindspore.Tensor(fill_value))
     return legacy.cast(legacy.fill_v2(size, fill_value), dtype)
 
 def zeros_like(input, dtype):
@@ -123,6 +125,9 @@ def div(input, other):
     return legacy.div(input, other)
 
 def mul(input, other):
+    if input.dtype == core.bool:
+        if isinstance(other, bool) or (not isinstance(other, numbers.Number) and other.dtype == core.bool):
+            return bitwise_and_scalar(input, other)
     return legacy.mul(input, other)
 
 def reduce_all(input, axis, keepdims):
@@ -253,6 +258,11 @@ def less(input, other):
     return legacy.less(input, other)
 
 def select(condition, x, y):
+    if isinstance(x, numbers.Number) or x.ndim == 0:
+        x = fill_scalar(condition.shape, x, None)
+    if isinstance(y, numbers.Number) or y.ndim == 0:
+        y = fill_scalar(condition.shape, y, None)
+
     return legacy.select(condition, x, y)
 
 def round(input, decimals):
@@ -317,7 +327,7 @@ def ones_like(input, dtype):
     return legacy.ones_like(input)
 
 def embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq):
-    return cast(legacy.gather(weight, input, 0, 0), weight.dtype)
+    return legacy.gather(weight, input, 0, 0)
 
 def linspace(start, end, steps, dtype):
     start = float(start)
@@ -325,8 +335,7 @@ def linspace(start, end, steps, dtype):
     return legacy.lin_space(mindspore.Tensor(start), mindspore.Tensor(end), steps)
 
 def masked_fill(input, mask, value):
-    if input.dtype.is_floating_point and isinstance(value, numbers.Number):
-        value = float(value)
+    value = fill_scalar((), value, input.dtype)
     return legacy.masked_fill(input, mask, value)
 
 def sum(input, dim, keepdim, dtype):
@@ -388,9 +397,14 @@ def layer_norm(input, normalized_shape, weight, bias, eps=1e-5):
     return legacy.layer_norm(input, weight, bias, begin_axis, begin_axis, eps)
 
 def argmin_with_value(input, axis, keep_dims):
+    if axis is None:
+        axis = -1
     return legacy.arg_min_with_value(input, axis, keep_dims)
 
 def argmax_with_value(input, axis, keep_dims):
+    if axis is None:
+        axis = -1
+
     return legacy.arg_max_with_value(input, axis, keep_dims)
 
 def silu(input):
@@ -425,9 +439,13 @@ def eye(n, m, dtype):
     return legacy.eye(n, m, dtype)
 
 def argmax(input, axis, keep_dims):
+    if axis is None:
+        axis = -1
     return legacy.arg_max_with_value(input, axis, keep_dims)[0]
 
 def argmin(input, axis, keep_dims):
+    if axis is None:
+        axis = -1
     return legacy.arg_min_with_value(input, axis, keep_dims)[0]
 
 def exp(input):
@@ -489,18 +507,7 @@ def scatter(input, dim, index, src):
     return legacy.tensor_scatter_elements(input, index, src, dim, "none")
 
 def batch_norm(input, weight, bias, running_mean=None, runnning_var=None, training=False, momentum=0.1, epsilon=1e-5):
-    input_ndim = input.ndim
-    if input_ndim == 2:
-        return legacy.batch_norm(input, weight, bias, running_mean, runnning_var, training, epsilon, momentum, 'NCHW')
-    else:
-        input = transpose_view(input, 1, -1)
-        input_shape = input.shape
-        input = reshape(input, (-1, input.shape[-1]))
-        outs = legacy.batch_norm(input, weight, bias, running_mean, runnning_var, training, epsilon, momentum, 'NCHW')
-        out = reshape(outs[0], (*input_shape[:-1], -1))
-        out = transpose_view(out, 1, -1)
-
-        return out, outs[1], outs[2]
+    return legacy.batch_norm(input, weight, bias, running_mean, runnning_var, training, epsilon, momentum, 'NCHW')
 
 def tanh(input):
     return legacy.tanh(input)
@@ -797,7 +804,7 @@ def max_pool2d(input, kernel_size, stride=1, padding=0, dilation=1, ceil_mode=Fa
     return out
 
 def baddbmm(input, batch1, batch2, alpha=1, beta=1):
-    return add(mul(beta, input), mul(alpha, bmm(batch1, batch2)))
+    return add(mul(input, beta), mul(bmm(batch1, batch2), alpha))
 
 def softplus(input, beta=1, threshold=20):
     return legacy.softplus(input)
@@ -805,9 +812,6 @@ def softplus(input, beta=1, threshold=20):
 def gather_nd(input, indices):
     return legacy.gather_nd(input, indices)
 
-def unique_consecutive(input, return_inverse, return_counts, dim):
-    return legacy.unique_consecutive(input, return_inverse, return_counts, dim)
-
 def meshgrid(input, lambd):
     return legacy.meshgrid(input, lambd)
@@ -815,7 +819,7 @@ def addcmul(input, tensor1, tensor2, value=1.0):
     return legacy.addcmul(input, tensor1, tensor2, mindspore.Tensor(value))
 
 def addmm(input, mat1, mat2, alpha=1.0, beta=1.0):
-    return add(mul(beta, input), mul(alpha, bmm(mat1, mat2)))
+    return add(mul(input, beta), mul(bmm(mat1, mat2), alpha))
 
 def im2col(input, kernel_size, dilation=1, padding=0, stride=1):
     out = legacy.im2_col(input, kernel_size, stride, dilation, padding)
@@ -1101,6 +1105,8 @@ def bernoulli(input, generator):
     return legacy.bernoulli(input, seed, offset)
 
 def arange(start, end, step, dtype):
+    if dtype is not None:
+        return cast(legacy.range(start, end, step, 100000), dtype)
     return legacy.range(start, end, step, 100000)
 
 def inplace_fill_scalar(input, value):
@@ -1121,3 +1127,13 @@ def inplace_uniform(input, from_, to_, generator_):
                                 mindspore.tensor(from_, dtype=mindspore.int32),
                                 mindspore.tensor(to_, dtype=mindspore.int32), 0, 0)
     return input.assign_value(value)
+
+def right_shift(input, other):
+    return legacy.right_shift(input, other)
+
+def inplace_fill_tensor(input, value):
+    input.assign_value(fill_tensor(input.shape, value, None))
+    return input
+
+def search_sorted(sorted_sequence, values, sorter, dtype, right):
+    return legacy.search_sorted(sorted_sequence, values, sorter, dtype, right)
\ No newline at end of file
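For context on the gpu.py `mul` change above: on boolean inputs, elementwise multiplication and logical AND produce the same result, which is why bool tensors can be routed through a bitwise-and kernel. A minimal, standalone illustration (NumPy is used here only as a stand-in for the tensor backend):

    import numpy as np

    a = np.array([True, True, False, False])
    b = np.array([True, False, True, False])
    # For booleans, a * b equals elementwise logical AND.
    assert np.array_equal(a * b, np.logical_and(a, b))
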
diff --git a/mindnlp/core/_apis/npu.py b/mindnlp/core/_apis/npu.py
index c5d463a3e..b8f92b38d 100644
--- a/mindnlp/core/_apis/npu.py
+++ b/mindnlp/core/_apis/npu.py
@@ -1594,3 +1594,13 @@ def bernoulli(input, generator):
 def multinomial(input, num_samples, replacement, generator):
     seed, offset = generator._step(12)  # pylint: disable=protected-access
     return pyboost.multinomial_ext_op(input, num_samples, replacement, seed, offset)
+
+def right_shift(input, other):
+    if use_pyboost():
+        return pyboost.right_shift_op(input, other)
+    return legacy.right_shift(input, other)
+
+def histc(input, bins=100, min=0, max=0):
+    if use_pyboost():
+        return pyboost.histc_ext_op(input, bins, float(min), float(max))
+    return legacy.histogram(input, bins, float(min), float(max))
diff --git a/mindnlp/core/_tensor.py b/mindnlp/core/_tensor.py
index 7b30d19cc..3656c52f7 100644
--- a/mindnlp/core/_tensor.py
+++ b/mindnlp/core/_tensor.py
@@ -110,6 +110,7 @@ def __init__(self, *args, **kwargs):
 Tensor.__init__ = __init__
 
 origin_setitem = Tensor.__setitem__
+origin_is_contiguous = Tensor.is_contiguous
 Tensor._requires_grad = False
 
 def tensor(data, *, dtype=None, device=None, requires_grad=False):
@@ -1253,7 +1254,8 @@ def hardshrink(self, lambd=0.5):
 
 # Tensor.histc
-
+def histc(self, bins=100, min=0, max=0):
+    return ops.histc(self, bins, min, max)
 
 # Tensor.histogram
@@ -1364,8 +1366,8 @@ def isnan(self):
     return ops.isnan(self)
 
 # Tensor.is_contiguous
-# def is_contiguous(self):
-#     return self.is_contiguous()
+def is_contiguous(self, memory_format=None):
+    return origin_is_contiguous(self)
 
 # Tensor.is_complex
 def is_complex(self):
diff --git a/mindnlp/core/cuda/__init__.py b/mindnlp/core/cuda/__init__.py
index 82e69ae30..59e1c6c97 100644
--- a/mindnlp/core/cuda/__init__.py
+++ b/mindnlp/core/cuda/__init__.py
@@ -60,8 +60,19 @@ def __exit__(self, type: Any, value: Any, traceback: Any):
 def is_bf16_supported():
     return False
 
-def mem_get_info(index):
-    return (1024, 1024)
+def mem_get_info(device=None):
+    if not isinstance(device, int):
+        device = mindspore.context.get_context("device_id")
+
+    res = mindspore.hal.get_device_properties(device)
+    return (res.total_memory, res.total_memory)
+
+def get_device_capability(device=None):
+    if not isinstance(device, int):
+        device = mindspore.context.get_context("device_id")
+
+    res = mindspore.hal.get_device_properties(device)
+    return (res.major, res.minor)
 
 def memory_reserved(device=None):
     return ms_memory_reserved()
diff --git a/mindnlp/core/nn/functional.py b/mindnlp/core/nn/functional.py
index 80461f421..634612fe7 100644
--- a/mindnlp/core/nn/functional.py
+++ b/mindnlp/core/nn/functional.py
@@ -274,7 +274,7 @@ def pad(input, pad, mode='constant', value=None):
     if isinstance(pad, tuple):
         pad = tuple(p if isinstance(p, int) else p.item() for p in pad)
 
-    if input.device.type in ['cpu', 'meta'] or ON_A1:
+    if input.device.type in ['cpu', 'meta', 'cuda'] or ON_A1:
         new_pad = ()
         for idx, pad_v in enumerate(pad):
             if not isinstance(pad_v, int):
@@ -301,6 +301,8 @@ def pad(input, pad, mode='constant', value=None):
                 value = bool(value)
             elif input.dtype in [core.int32, core.int64]:
                 value = int(value)
+        if input.device.type == 'cuda' and len(new_pad) == 8:
+            return execute('pad_v3', input, new_pad[:-2], mode, value)
         return execute('pad_v3', input, new_pad, mode, value)
 
     out = input
     if (isinstance(pad, tuple) and not pad):
@@ -324,9 +326,9 @@ def pad(input, pad, mode='constant', value=None):
     return out
 
 def nll_loss(input, target, weight=None, ignore_index=-100, reduction='mean'):
-    # if input.device.type == 'npu':
-    return _nllloss_nd(input, target, weight, ignore_index, reduction)
-    # return _inner_nll_loss(input, target, weight, ignore_index, reduction)
+    if input.device.type in ['npu', 'cpu']:
+        return _nllloss_nd(input, target, weight, ignore_index, reduction)
+    return _inner_nll_loss(input, target, weight, ignore_index, reduction)
 
 def _inner_nll_loss(inputs, target, weight=None, ignore_index=-100, reduction='mean', label_smoothing=0.0):
     ndim = inputs.ndim
@@ -352,7 +354,7 @@ def _inner_nll_loss(inputs, target, weight=None, ignore_index=-100, reduction='m
 def _nll_loss(inputs, target, target_dim=-1, weight=None, ignore_index=None, reduction='none', label_smoothing=0.0):
     """nll loss inner function"""
     if target.ndim == inputs.ndim - 1:
-        target = target.expand_dims(target_dim)
+        target = target.unsqueeze(target_dim)
     if ignore_index is not None:
         non_pad_mask = core.eq(target, ignore_index)
         target = target.masked_fill(non_pad_mask, core.cast(0, target.dtype))
@@ -366,10 +368,10 @@ def _nll_loss(inputs, target, target_dim=-1, weight=None, ignore_index=None, red
             weight = weight.view(weight.shape + (1,))
             weighted_inputs = inputs * weight
             weighted_inputs = weighted_inputs.view(orig_shape)
-        loss = core.neg(core.gather_d(weighted_inputs, target_dim, target))
+        loss = core.neg(core.gather(weighted_inputs, target_dim, target))
         smooth_loss = core.neg(weighted_inputs.sum(axis=target_dim, keepdims=True))
     else:
-        loss = core.neg(core.gather_d(inputs, target_dim, target))
+        loss = core.neg(core.gather(inputs, target_dim, target))
         smooth_loss = core.neg(inputs.sum(axis=target_dim, keepdims=True))
 
     loss_weights = core.ones_like(loss)
@@ -427,11 +429,42 @@ def _nllloss_nd(input, target, weight=None, ingore_index=-100, reduction='mean')
         ret = execute('nllloss_2d', input, target, weight, reduction, ingore_index)[0]
         return ret.view(out_size)
 
+
+def cross_entropy_gpu(input, target, weight=None, ignore_index=-100, reduction='mean', label_smoothing=0.0):
+    class_dim = 0 if input.ndim == 1 else 1
+    if target.dtype.is_floating_point:
+        return _cross_entropy(input, target, class_dim, weight, reduction, label_smoothing)
+    return nll_loss(log_softmax(input, class_dim), target, weight, ignore_index, reduction)
+
+def _cross_entropy(inputs, target, target_dim, weight=None, reduction='mean', label_smoothing=0.0):
+    """cross entropy inner function"""
+    class_dim = 0 if inputs.ndim == 1 else 1
+    n_classes = inputs.shape[class_dim]
+    inputs = log_softmax(inputs, class_dim)
+    if label_smoothing > 0.0:
+        target = target * (1 - label_smoothing) + label_smoothing / n_classes
+
+    if weight is None:
+        weight = core.ones_like(inputs)
+    elif inputs.ndim != 1:
+        broadcast_shape = [1 for _ in range(inputs.ndim)]
+        broadcast_shape[1] = weight.shape[0]
+        weight = weight.reshape(broadcast_shape)
+
+    if reduction == 'mean':
+        return -(inputs * target * weight).sum() / (inputs.numel() / n_classes)
+    if reduction == 'sum':
+        return -(inputs * target * weight).sum()
+    return -(inputs * target * weight).sum(class_dim)
+
+
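The `_cross_entropy` helper above handles probability (soft-label) targets: it smooths the target distribution, then averages -sum(target * log_softmax(input)) over the batch. A small self-contained sketch of that computation in plain NumPy (illustration only, not the mindnlp API):

    import numpy as np

    def log_softmax(x, axis):
        x = x - x.max(axis=axis, keepdims=True)
        return x - np.log(np.exp(x).sum(axis=axis, keepdims=True))

    logits = np.array([[2.0, 0.5, 0.1], [0.2, 1.5, 0.3]])
    target = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])  # class probabilities
    ls, n_classes = 0.1, logits.shape[1]
    target = target * (1 - ls) + ls / n_classes             # label smoothing
    loss = -(log_softmax(logits, 1) * target).sum(axis=1).mean()  # 'mean' reduction
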
     if label_smoothing < 0.0 or label_smoothing > 1.0:
         raise ValueError(f"For cross_entropy, label_smoothing must in [0, 1]")
     if input.ndim == 0 or input.shape[0] == 0:
         raise ValueError(f"For cross_entropy, input don't support 0-dim and shape[0].")
+    if input.device.type == 'cuda':
+        return cross_entropy_gpu(input, target, weight, ignore_index, reduction, label_smoothing)
     class_dim = 0 if input.ndim == 1 else 1
     n_classes = input.shape[class_dim]
     input = log_softmax(input, class_dim, dtype=input.dtype)
@@ -675,10 +708,10 @@ def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corne
         )
     if input.dim() == 4 and mode == "bicubic":
         assert align_corners is not None
-        if antialias:
-            return torch._C._nn._upsample_bicubic2d_aa(
-                input, output_size, align_corners, scale_factors
-            )
+        # if antialias:
+        #     return torch._C._nn._upsample_bicubic2d_aa(
+        #         input, output_size, align_corners, scale_factors
+        #     )
         return execute(
             'upsample_bicubic2d', input, output_size, scale_factors, align_corners
         )
@@ -1146,8 +1179,8 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.
         else:
             attn_bias = attn_mask + attn_bias
 
-    attn_weight = query.float() @ key.transpose(-2, -1).float() * scale_factor
-    attn_weight += attn_bias.float()
+    attn_weight = query @ key.transpose(-2, -1) * scale_factor
+    attn_weight += attn_bias
     attn_weight = softmax(attn_weight, dim=-1, dtype=core.float32).to(query.dtype)
     attn_weight = dropout(attn_weight, dropout_p, training=True)
     return attn_weight @ value
diff --git a/mindnlp/core/ops/_inner.py b/mindnlp/core/ops/_inner.py
index 2f7c35915..71f9c777c 100644
--- a/mindnlp/core/ops/_inner.py
+++ b/mindnlp/core/ops/_inner.py
@@ -16,7 +16,14 @@ def npu_clear_float_status_v2(status):
 def all_finite(inputs):
     return execute('all_finite', inputs)
 
+def custom_masked_scatter_vec(input, mask, source):
+    output = input.clone()
+    output[mask] = source.flatten()  # the key line: vectorized assignment through the boolean mask
+    return output
+
 def masked_scatter(input, mask, source):
+    if input.device.type == 'cuda':
+        return custom_masked_scatter_vec(input, mask, source)
     return execute('masked_scatter', input, mask, source)
 
 __all__ = [
diff --git a/mindnlp/core/ops/array.py b/mindnlp/core/ops/array.py
index 97f60f6e6..3ece443d4 100644
--- a/mindnlp/core/ops/array.py
+++ b/mindnlp/core/ops/array.py
@@ -920,7 +920,7 @@ def _tensordot(a, b):
         b = broadcast_to(b, a.shape)
         return core.sum(a * b, dim=-1)
 
-    stacked_indices = _tensordot(stacked_indices, core.tensor(index_scaling))
+    stacked_indices = _tensordot(stacked_indices, core.tensor(index_scaling).to(stacked_indices.device))
 
     flat_shape = shape_tensor[:axis] + (-1,) + shape_tensor[axis + len(dims) :]
     tensor = tensor.reshape(flat_shape)
diff --git a/mindnlp/core/ops/other.py b/mindnlp/core/ops/other.py
index 076c1b9b6..ae2e1e245 100644
--- a/mindnlp/core/ops/other.py
+++ b/mindnlp/core/ops/other.py
@@ -669,7 +669,40 @@ def flip(input, dims):
 
 # histc
-
+def manual_histc_searchsorted(input_tensor, bins=100, min=0, max=0):
+    """
+    Implement histc via searchsorted; suited to floating-point inputs and
+    models the bin boundaries more precisely.
+    """
+    if min == 0 and max == 0:
+        min = input_tensor.min().item()
+        max = input_tensor.max().item()
+
+    bin_width = (max - min) / bins
+    # Generate the bin edges (all but the last bin are half-open; histc's last bin is a closed interval, e.g. [2, 5]).
+    bin_edges = core.linspace(min, max, bins + 1, device=input_tensor.device)
+    # Set the right edge of the last bin to infinity so that values equal to max land in the last bin,
+    # while every other bin stays left-closed/right-open.
+    bin_edges[-1] = float('inf')
+
+    flattened = input_tensor.view(-1)
+    # Find where each element would be inserted into bin_edges, then subtract 1 to get its bin index.
+    # side='right' returns the index i such that sorted_sequence[i-1] <= v < sorted_sequence[i].
+    indices = core.searchsorted(bin_edges, flattened, side='right') - 1
+
+    # Values below min end up with index -1; drop them.
+    valid_mask = (indices >= 0)
+    indices_valid = indices[valid_mask]
+    # Also clamp to bins - 1 (with bin_edges[-1] = inf this should not trigger, but keep it as a safeguard).
+    indices_valid = core.clamp(indices_valid, 0, bins - 1)
+
+    # Count the valid indices with bincount.
+    histogram = core.bincount(indices_valid, minlength=bins)
+    return histogram.float()  # keep the output dtype consistent with histc
+
+def histc(input, bins=100, min=0, max=0):
+    if input.device.type == 'cuda':
+        return manual_histc_searchsorted(input, bins, min, max)
+    return execute('histc', input, bins, min, max)
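The bucketing rule implemented by `manual_histc_searchsorted` can be checked in isolation: bins are left-closed/right-open except the last one, which also includes `max`. A small NumPy-only sketch of the same rule (illustrative, assuming a torch.histc-style [min, max] range):

    import numpy as np

    values = np.array([2.0, 2.9, 3.0, 5.0])
    edges = np.linspace(2.0, 5.0, 4)   # 3 bins over [2, 5]
    edges[-1] = np.inf                 # last bin also captures values equal to max
    idx = np.searchsorted(edges, values, side='right') - 1
    hist = np.bincount(np.clip(idx, 0, 2), minlength=3)
    # hist == [2, 1, 1]: two values in [2, 3), one in [3, 4), one in [4, 5]
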
 
 
 # histogram
@@ -1077,5 +1110,6 @@ def cosine_similarity(x1, x2, dim=1, eps=1e-8):
     'view_as_real',
     'bucketize',
     'cosine_similarity',
-    'detach'
+    'detach',
+    'histc'
 ]
diff --git a/mindnlp/core/ops/reduction.py b/mindnlp/core/ops/reduction.py
index 39b9ae88d..3ca127784 100644
--- a/mindnlp/core/ops/reduction.py
+++ b/mindnlp/core/ops/reduction.py
@@ -8,7 +8,8 @@
 min_out = namedtuple('min_out', ['values', 'indices'])
 
 # argmax
-def argmax(input, dim=None, keepdim=False):
+def argmax(input, dim=None, keepdim=False, **kwargs):
+    dim = kwargs.pop('axis', dim)
     return execute('argmax', input, dim, keepdim)
 
 # argmin
@@ -192,8 +193,165 @@ def unique(input, sorted=True, return_inverse=False, return_counts=False, dim=No
             return y, counts
     return y
 
-# unique_consecutive
+def unique_consecutive_optimized(input, return_inverse=False, return_counts=False, dim=None):
+    """
+    An optimized manual implementation of torch.unique_consecutive.
+    """
+    if dim is None:
+        input_flat = input.flatten()
+        return _unique_consecutive_1d(input_flat, return_inverse, return_counts)
+    else:
+        return _unique_consecutive_nd(input, dim, return_inverse, return_counts)
+
+def _unique_consecutive_1d(input, return_inverse, return_counts):
+    """Optimized implementation for 1-D tensors."""
+    if input.numel() == 0:
+        return _handle_empty_input(input, return_inverse, return_counts)
+
+    # Locate the change points.
+    diff = input[1:] != input[:-1]
+    change_points = core.cat([
+        core.tensor([True], device=input.device),
+        diff,
+        core.tensor([True], device=input.device)
+    ])
+    change_indices = core.where(change_points)[0]
+
+    # Extract the unique values.
+    unique_values = input[change_indices[:-1]]
+
+    # Assemble the results to return.
+    result = [unique_values]
+
+    # Inverse indices.
+    if return_inverse:
+        inverse_indices = core.repeat_interleave(
+            core.arange(len(unique_values), device=input.device),
+            core.diff(change_indices)
+        )
+        result.append(inverse_indices)
+
+    # Counts.
+    if return_counts:
+        counts = core.diff(change_indices)
+        result.append(counts)
+
+    return result[0] if len(result) == 1 else tuple(result)
+
+def _unique_consecutive_nd(input, dim, return_inverse, return_counts):
+    """Implementation for multi-dimensional tensors."""
+    # Move the target dimension to the last axis.
+    input_transposed = input.transpose(dim, -1)
+    original_shape = input_transposed.shape
+    input_2d = input_transposed.reshape(-1, original_shape[-1])
+
+    results = []
+    for i in range(input_2d.shape[0]):
+        slice_result = _unique_consecutive_1d(input_2d[i], return_inverse, return_counts)
+        if isinstance(slice_result, tuple):
+            results.append(slice_result)
+        else:
+            results.append((slice_result,))
+
+    # Reassemble the per-slice results.
+    return _reconstruct_nd_results(results, original_shape, dim, return_inverse, return_counts)
+
+def _handle_empty_input(input, return_inverse, return_counts):
+    """Handle empty input."""
+    empty_tensor = core.tensor([], dtype=input.dtype, device=input.device)
+    if return_inverse and return_counts:
+        return empty_tensor, core.tensor([], dtype=core.long, device=input.device), core.tensor([], dtype=core.long, device=input.device)
+    elif return_inverse:
+        return empty_tensor, core.tensor([], dtype=core.long, device=input.device)
+    elif return_counts:
+        return empty_tensor, core.tensor([], dtype=core.long, device=input.device)
+    else:
+        return empty_tensor
+
+def _reconstruct_nd_results(results, original_shape, dim, return_inverse, return_counts):
+    """
+    Reassemble the per-slice results for the multi-dimensional case.
+    """
+    # Determine the maximum number of unique values (used for padding).
+    max_unique_len = max(len(result[0]) for result in results)
+    batch_size = original_shape[0]  # size of the first dimension (product of the remaining dimensions)
+
+    # Reassemble the unique-value tensor.
+    unique_dtype = results[0][0].dtype
+    unique_device = results[0][0].device
+
+    # Create the output tensor for the unique values.
+    unique_output_shape = list(original_shape)
+    unique_output_shape[-1] = max_unique_len
+    unique_output = core.full(unique_output_shape, 0, dtype=unique_dtype, device=unique_device)
+
+    # Fill the unique-value tensor.
+    for i, result in enumerate(results):
+        unique_slice = result[0]
+        unique_output[i, :len(unique_slice)] = unique_slice
+
+    # Reshape back to the original shape (excluding the processed dimension).
+    final_unique_shape = list(original_shape[:-1]) + [max_unique_len]
+    unique_output = unique_output.reshape(final_unique_shape)
+
+    # Restore the original dimension order if needed.
+    if dim != -1:
+        # Compute the original dimension order.
+        dim_perm = list(range(unique_output.dim()))
+        # Move the last dimension back to its original position.
+        dim_perm.insert(dim, dim_perm.pop(-1))
+        # Adjust the dimension order.
+        unique_output = unique_output.permute(dim_perm)
+
+    # Assemble the outputs.
+    output_results = [unique_output]
+
+    # Inverse indices.
+    if return_inverse:
+        inverse_shape = list(original_shape)
+        inverse_output = core.zeros(inverse_shape, dtype=core.long, device=unique_device)
+
+        for i, result in enumerate(results):
+            if len(result) > 1:  # make sure inverse indices are present
+                inverse_slice = result[1]
+                inverse_output[i, :len(inverse_slice)] = inverse_slice
+
+        # Reshape the inverse indices back to the original shape.
+        inverse_output = inverse_output.reshape(list(original_shape[:-1]) + [original_shape[-1]])
+
+        # Adjust the dimension order.
+        if dim != -1:
+            inverse_output = inverse_output.permute(dim_perm)
+
+        output_results.append(inverse_output)
+
+    # Counts.
+    if return_counts:
+        counts_shape = list(original_shape[:-1]) + [max_unique_len]
+        counts_output = core.zeros(counts_shape, dtype=core.long, device=unique_device)
+
+        for i, result in enumerate(results):
+            counts_index = 2 if return_inverse else 1  # position of the counts in the per-slice result
+            if len(result) > counts_index:
+                counts_slice = result[counts_index]
+                counts_output[i, :len(counts_slice)] = counts_slice
+
+        # Adjust the dimension order of the counts tensor.
+        if dim != -1:
+            counts_output = counts_output.permute(dim_perm)
+
+        output_results.append(counts_output)
+
+    # Return the appropriate combination of results.
+    if len(output_results) == 1:
+        return output_results[0]
+    else:
+        return tuple(output_results)
+
+
 def unique_consecutive(input, return_inverse=False, return_counts=False, dim=None):
+    if input.device.type == 'cuda':
+        return unique_consecutive_optimized(input, return_inverse, return_counts, dim)
     output, idx, counts = execute('unique_consecutive', input, return_inverse, return_counts, dim)
     if return_inverse and return_counts:
         return output, idx, counts
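
The 1-D path of the new `unique_consecutive` implementation rests on a change-point trick: flag every position where a value differs from its predecessor; the flagged positions give the start of each run, the run starts give the unique values, and the gaps between flags give the counts. A standalone NumPy sketch of that idea (illustration only):

    import numpy as np

    x = np.array([1, 1, 2, 2, 2, 3, 1, 1])
    change = np.concatenate(([True], x[1:] != x[:-1], [True]))
    starts = np.flatnonzero(change)
    values = x[starts[:-1]]                               # [1, 2, 3, 1]
    counts = np.diff(starts)                              # [2, 3, 1, 2]
    inverse = np.repeat(np.arange(len(values)), counts)   # [0, 0, 1, 1, 1, 2, 3, 3]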