@@ -1634,14 +1634,35 @@ class ConvertAtenCumsumOp : public OpConversionPattern<AtenCumsumOp> {
     auto resultType = cast<RankedTensorType>(
         getTypeConverter()->convertType(op->getResult(0).getType()));
     Type elementType = resultType.getElementType();
-    Type inputElementType =
-        cast<RankedTensorType>(input.getType()).getElementType();
+    auto inputType = cast<RankedTensorType>(input.getType());
+    Type inputElementType = inputType.getElementType();
 
-    // Converting the input element type to the result's element type.
-    // The only possible mismatch would be when the input element type is an
-    // integer but not `si64`. Therefore, we directly convert the input to
-    // `si64`. Rest all cases are handled in the dtype definition for this op.
-    if (elementType != inputElementType) {
+    Value dtype = op.getDtype();
+    if (!isa<Torch::NoneType>(dtype.getType())) {
+      int64_t dtypeInt;
+      if (!matchPattern(dtype, m_TorchConstantInt(&dtypeInt)))
+        return rewriter.notifyMatchFailure(
+            op, "unimplemented: only constant int dtype value is supported");
+
+      FailureOr<Type> resDtype = getTypeForScalarType(
+          op->getContext(), (torch_upstream::ScalarType)dtypeInt);
+      if (failed(resDtype))
+        return rewriter.notifyMatchFailure(
+            op, "unsupported: dtype not defined for the given dtype int value");
+
+      Value torchInput =
+          convertTensorToDtype(rewriter, loc, op.getSelf(), resDtype.value());
+      input = typeConverter->materializeTargetConversion(
+          rewriter, loc, typeConverter->convertType(torchInput.getType()),
+          torchInput);
+    } else if (elementType != inputElementType &&
+               isa<mlir::IntegerType>(elementType) &&
+               isa<mlir::IntegerType>(inputElementType)) {
+      // Converting the input element type to the result's element type.
+      // The only possible mismatch would be when the input element type is an
+      // integer but not `si64` and the `dtype` is not specified. Therefore, we
+      // directly convert the input to `si64`. Rest all cases are handled in the
+      // dtype definition for this op.
       Value torchInput = convertTensorToDtype(
           rewriter, loc, op.getSelf(),
           rewriter.getIntegerType(64, IntegerType::Signed));
@@ -1650,12 +1671,7 @@ class ConvertAtenCumsumOp : public OpConversionPattern<AtenCumsumOp> {
           torchInput);
     }
 
-    int64_t inputRank = resultType.getRank();
-    Value dtype = op.getDtype();
-    if (!isa<Torch::NoneType>(dtype.getType()))
-      return rewriter.notifyMatchFailure(
-          op, "unsupported: dtype argument not supported");
-
+    int64_t inputRank = inputType.getRank();
     int64_t dim;
     if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
       return rewriter.notifyMatchFailure(
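As a rough illustration (not part of this commit, and assuming standard PyTorch semantics for torch.cumsum), this is the kind of call that previously hit the "dtype argument not supported" match failure and is now handled by converting the input to the requested dtype before lowering:

    import torch

    x = torch.arange(5, dtype=torch.int32)
    # An explicit dtype used to make this cumsum lowering bail out; with the
    # change above, the input is converted to the requested dtype first.
    y = torch.cumsum(x, dim=0, dtype=torch.float64)
    print(y.dtype)  # torch.float64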