
Commit 7590515

[MLIR][TORCH] Add dtype arg support for torch.cumsum op (#4155)
This commit adds support for the dtype argument of the torch.cumsum op, which specifies the desired data type of the returned tensor. Fixes #3866.

Signed-off-by: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
1 parent 9e91403 commit 7590515
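
For context, a minimal PyTorch-level sketch of the behavior this commit enables in the TorchToTMTensor lowering; the shapes and dtypes below are illustrative only, chosen to mirror the new e2e test:

import torch

# cumsum with an explicit dtype: the input is cast to the requested element
# type before accumulating, and the result carries that dtype.
x = torch.randint(0, 2, (2, 7, 4)).to(torch.bool)
out = torch.cumsum(x, dim=1, dtype=torch.float32)
print(out.dtype)  # torch.float32

Previously the TorchToTMTensor pattern bailed out with "unsupported: dtype argument not supported" whenever dtype was not None; after this change the requested dtype is honored, as the diff below shows.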

3 files changed (+50, -13 lines)

lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp

Lines changed: 29 additions & 13 deletions
@@ -1634,14 +1634,35 @@ class ConvertAtenCumsumOp : public OpConversionPattern<AtenCumsumOp> {
     auto resultType = cast<RankedTensorType>(
         getTypeConverter()->convertType(op->getResult(0).getType()));
     Type elementType = resultType.getElementType();
-    Type inputElementType =
-        cast<RankedTensorType>(input.getType()).getElementType();
+    auto inputType = cast<RankedTensorType>(input.getType());
+    Type inputElementType = inputType.getElementType();
 
-    // Converting the input element type to the result's element type.
-    // The only possible mismatch would be when the input element type is an
-    // integer but not `si64`. Therefore, we directly convert the input to
-    // `si64`. Rest all cases are handled in the dtype definition for this op.
-    if (elementType != inputElementType) {
+    Value dtype = op.getDtype();
+    if (!isa<Torch::NoneType>(dtype.getType())) {
+      int64_t dtypeInt;
+      if (!matchPattern(dtype, m_TorchConstantInt(&dtypeInt)))
+        return rewriter.notifyMatchFailure(
+            op, "unimplemented: only constant int dtype value is supported");
+
+      FailureOr<Type> resDtype = getTypeForScalarType(
+          op->getContext(), (torch_upstream::ScalarType)dtypeInt);
+      if (failed(resDtype))
+        return rewriter.notifyMatchFailure(
+            op, "unsupported: dtype not defined for the given dtype int value");
+
+      Value torchInput =
+          convertTensorToDtype(rewriter, loc, op.getSelf(), resDtype.value());
+      input = typeConverter->materializeTargetConversion(
+          rewriter, loc, typeConverter->convertType(torchInput.getType()),
+          torchInput);
+    } else if (elementType != inputElementType &&
+               isa<mlir::IntegerType>(elementType) &&
+               isa<mlir::IntegerType>(inputElementType)) {
+      // Converting the input element type to the result's element type.
+      // The only possible mismatch would be when the input element type is an
+      // integer but not `si64` and the `dtype` is not specified. Therefore, we
+      // directly convert the input to `si64`. Rest all cases are handled in the
+      // dtype definition for this op.
       Value torchInput = convertTensorToDtype(
           rewriter, loc, op.getSelf(),
           rewriter.getIntegerType(64, IntegerType::Signed));
@@ -1650,12 +1671,7 @@ class ConvertAtenCumsumOp : public OpConversionPattern<AtenCumsumOp> {
           torchInput);
     }
 
-    int64_t inputRank = resultType.getRank();
-    Value dtype = op.getDtype();
-    if (!isa<Torch::NoneType>(dtype.getType()))
-      return rewriter.notifyMatchFailure(
-          op, "unsupported: dtype argument not supported");
-
+    int64_t inputRank = inputType.getRank();
     int64_t dim;
     if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
       return rewriter.notifyMatchFailure(
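
A note on the else-if branch above: when dtype is None, PyTorch accumulates integral (and bool) cumsum inputs in int64, which is why the lowering still casts such inputs to `si64` in that case, as the retained comment explains. A quick illustrative check of that default promotion (assumed behavior of recent PyTorch releases, consistent with the comment in the lowering):

import torch

# With dtype=None, integral and bool inputs are accumulated in int64,
# while floating-point inputs keep their element type.
print(torch.cumsum(torch.ones(3, dtype=torch.int32), 0).dtype)    # torch.int64
print(torch.cumsum(torch.ones(3, dtype=torch.float32), 0).dtype)  # torch.float32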

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 1 addition & 0 deletions
@@ -3579,6 +3579,7 @@
     "ConvolutionModule2DTranspose_basic",
     "ConvolutionModule2DGroupedTranspose_basic",
     "CumsumInputDtypeInt32Module_basic",
+    "CumsumWithDtypeModule_basic",
     "CumsumModule_basic",
     "CumsumStaticModule_basic",
     "CumsumStaticNegativeDimModule_basic",

projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py

Lines changed: 20 additions & 0 deletions
@@ -5025,6 +5025,26 @@ def CumsumInputDtypeInt32Module_basic(module, tu: TestUtils):
     module.forward(tu.randint(2, 7, 4).to(torch.int32))
 
 
+class CumsumWithDtypeModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([2, 7, 4], torch.bool, True),
+        ]
+    )
+    def forward(self, val):
+        return torch.ops.aten.cumsum(val, dim=1, dtype=6)
+
+
+@register_test_case(module_factory=lambda: CumsumWithDtypeModule())
+def CumsumWithDtypeModule_basic(module, tu: TestUtils):
+    module.forward(tu.randint(2, 7, 4, low=-1, high=10).to(torch.bool))
+
+
 # ==============================================================================
 
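
One note on the new test above: `dtype=6` is the integer value of PyTorch's ScalarType enumeration for float32, which is the integer that `getTypeForScalarType` consumes in the lowering. The raw aten op accepts the integer form directly, so the mapping can be checked empirically (illustrative only):

import torch

x = torch.ones(2, 7, 4, dtype=torch.bool)
out = torch.ops.aten.cumsum(x, 1, dtype=6)  # 6 corresponds to float32 in the ScalarType enum
print(out.dtype)  # torch.float32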