
Commit c51c87f

submit the llvm#3902 to local repo (llvm#5)
* Decompose lstm and gru.
* Add tests and update xfail_sets.py
* Rebase main
* Fix casting for arith.cmpi operands to be of same type.
1 parent 95674e2 commit c51c87f


5 files changed: +227 −32 lines changed


lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp

Lines changed: 19 additions & 3 deletions
@@ -417,6 +417,21 @@ class ConvertAtenEmbeddingBagPaddingIdxOp
 };
 } // namespace

+static Value wrapIndicesAroundMax(OpBuilder &b, Location loc, Value index,
+                                  Value input, int64_t dim) {
+  // Performs the operation index = index % maxIndex to wrap index around
+  // maxIndex.
+  Value maxIndexValue = getDimOp(b, loc, input, dim);
+  maxIndexValue =
+      b.createOrFold<arith::IndexCastOp>(loc, index.getType(), maxIndexValue);
+  Value isBeyondMaxIndices = b.createOrFold<arith::CmpIOp>(
+      loc, arith::CmpIPredicate::sge, index, maxIndexValue);
+  Value wrappedIndices =
+      b.createOrFold<arith::RemSIOp>(loc, index, maxIndexValue);
+  return b.createOrFold<arith::SelectOp>(loc, isBeyondMaxIndices,
+                                         wrappedIndices, index);
+}
+
 namespace {
 // Let's say we have an input tensor: initialized with some random values of
 // size [4, 5, 6]. An index tensor (always 1-d): [0, 2] of size [2], and an

@@ -478,16 +493,17 @@ class ConvertAtenIndexSelectOp : public OpConversionPattern<AtenIndexSelectOp> {

     auto indexingMaps = AffineMap::inferFromExprList({indicesExpr, resultExpr},
                                                      rewriter.getContext());
-
     Value finalRes =
         rewriter
            .create<linalg::GenericOp>(
                loc, initTensor.getType(), ValueRange{indices}, initTensor,
                /*indexingMaps=*/indexingMaps,
                /*iteratorTypes=*/iteratorTypes,
                [&](OpBuilder &b, Location loc, ValueRange args) {
-                 Value index = rewriter.create<arith::IndexCastOp>(
-                     loc, rewriter.getIndexType(), args[0]);
+                 Value index =
+                     wrapIndicesAroundMax(b, loc, args[0], input, dimInt);
+                 index = rewriter.create<arith::IndexCastOp>(
+                     loc, rewriter.getIndexType(), index);
                  SmallVector<Value> indexTarget;
                  for (unsigned i = 0; i < inputRank; i++)
                    indexTarget.push_back(b.create<linalg::IndexOp>(loc, i));
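For intuition, the new linalg-path helper corresponds to the following Python sketch (illustrative only, not part of the commit; plain ints stand in for the MLIR Values). The select is the important detail: the remainder replaces the index only when it is at or beyond the dimension size, so in-range and negative indices pass through unchanged, and on the selected branch Python's % agrees with arith.remsi.

# Python sketch of wrapIndicesAroundMax (linalg path); each line names
# the arith op it stands in for.
def wrap_indices_around_max(index: int, max_index: int) -> int:
    is_beyond_max = index >= max_index           # arith.cmpi sge
    wrapped = index % max_index                  # arith.remsi
    return wrapped if is_beyond_max else index   # arith.select

assert wrap_indices_around_max(7, 5) == 2  # out-of-range index wraps
assert wrap_indices_around_max(3, 5) == 3  # in-range index is untouched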

lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 42 additions & 0 deletions
@@ -4223,6 +4223,42 @@ LogicalResult ConvertAtenOp<AtenGatherOp>::matchAndRewrite(
   return success();
 }

+Value wrapIndicesAroundMax(Value index, int maxIndex, Operation *op,
+                           ConversionPatternRewriter &rewriter) {
+  // Performs index = index % maxIndex to wrap index around maxIndex; TOSA
+  // has no remainder op, so compute index - (index / maxIndex) * maxIndex.
+
+  auto maxIndexValue =
+      tosa::getConstTensor<int32_t>(rewriter, op, maxIndex, {}).value();
+  auto maxIndexValueMinusOne =
+      tosa::getConstTensor<int32_t>(rewriter, op, maxIndex - 1, {}).value();
+
+  auto indexType = dyn_cast<RankedTensorType>(index.getType());
+  auto boolType = indexType.clone(rewriter.getIntegerType(1));
+
+  auto isBeyondMaxIndices = tosa::CreateOpAndInfer<tosa::GreaterOp>(
+      rewriter, op->getLoc(), boolType, index, maxIndexValueMinusOne);
+  auto wrappedBeyondMaxIndicesQuotient =
+      tosa::CreateOpAndInfer<tosa::IntDivOp>(rewriter, op->getLoc(), indexType,
+                                             index, maxIndexValue)
+          .getResult();
+  auto wrappedBeyondMaxIndicesQuotientTimesIndices =
+      tosa::createMulOpAndCast(rewriter, op, indexType,
+                               wrappedBeyondMaxIndicesQuotient,
+                               maxIndexValue,
+                               /*shift=*/0)
+          .getResult();
+  auto wrappedBeyondMaxIndices =
+      tosa::CreateOpAndInfer<tosa::SubOp>(
+          rewriter, op->getLoc(), indexType, index,
+          wrappedBeyondMaxIndicesQuotientTimesIndices)
+          .getResult();
+
+  return tosa::CreateOpAndInfer<tosa::SelectOp>(rewriter, op->getLoc(),
+                                                indexType, isBeyondMaxIndices,
+                                                wrappedBeyondMaxIndices, index);
+}
+
 template <>
 LogicalResult ConvertAtenOp<AtenIndexSelectOp>::matchAndRewrite(
     AtenIndexSelectOp op, OpAdaptor adaptor,
@@ -4271,6 +4307,10 @@ LogicalResult ConvertAtenOp<AtenIndexSelectOp>::matchAndRewrite(
           .value();
   }

+  int64_t selfNumElems = std::accumulate(inputShape.begin(), inputShape.end(),
+                                         1, std::multiplies<int64_t>());
+  index = wrapIndicesAroundMax(index, selfNumElems, op, rewriter);
+
   // Get positive dim
   int64_t dim;
   if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))

@@ -7704,10 +7744,12 @@ LogicalResult ConvertAtenOp<AtenAsStridedOp>::matchAndRewrite(
     // coord_i_n * stride[n]
     int32_t index = offset;
     int64_t coordFinder = i;
+
     for (int64_t dim = 0; dim < outputRank; dim++) {
       int64_t indexCoord = coordFinder % outputSize[outputRank - dim - 1];
       index += indexCoord * stride[outputRank - dim - 1];
       coordFinder /= outputSize[outputRank - dim - 1];
+      index = (index % selfNumElems);
     }
     targetIndicesVec.push_back(index);
   }
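TOSA offers integer division but no remainder op, so the helper above rebuilds index % maxIndex from the identity i mod m = i - (i / m) * m, guarded by a select so that only indices past maxIndex - 1 are rewritten. A minimal NumPy sketch of the same element-wise computation (illustrative only; the function name mirrors the C++ helper):

import numpy as np

# NumPy sketch of the TOSA helper; each line names the TOSA op it
# stands in for.
def wrap_indices_around_max(index: np.ndarray, max_index: int) -> np.ndarray:
    is_beyond_max = index > max_index - 1             # tosa.greater
    quotient = index // max_index                     # tosa.int_div
    wrapped = index - quotient * max_index            # tosa.mul, tosa.sub
    return np.where(is_beyond_max, wrapped, index)    # tosa.select

print(wrap_indices_around_max(np.array([3, 360, 719]), 360))  # [  3   0 359]

The AtenAsStridedOp change applies the same wrap incrementally: because (a + b) mod m = ((a mod m) + b) mod m, taking the modulo inside the loop after each stride contribution yields the same indices as a single final modulo while keeping the int32 accumulator bounded.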

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 3 additions & 0 deletions
@@ -497,6 +497,7 @@
     "SplitTensorNegativeDimModule_basic",
     "SplitWithSizesListUnpackModule_basic",
     "SplitWithSizes_Module_basic",
+    "AsStridedWithOffsetModule_basic",
     "AdaptiveAvgPool1dGeneralDynamic_basic",
     "AdaptiveAvgPool1dStaticEvenMultiple_basic",
     "AdaptiveAvgPool1dStaticLargerOutput_basic",

@@ -930,6 +931,7 @@
     "SplitTensorNegativeDimModule_basic",
     "SplitWithSizesListUnpackModule_basic",
     "SplitWithSizes_Module_basic",
+    "AsStridedWithOffsetModule_basic",
     "Unfold_Module_basic",
     "Unfold_Module_Rank_4",
     "Unfold_Module_Rank_Zero_basic",

@@ -1846,6 +1848,7 @@
     "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
     "AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic",
     "AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
+    "AsStridedWithOffsetModule_basic",
     "ElementwiseAtenLogicalNotOpPromoteModule_basic",
     "ElementwiseCosIntModule_basic",
     "ElementwiseReciprocalIntModule_basic",

projects/pt1/python/torch_mlir_e2e_test/test_suite/slice_like.py

Lines changed: 29 additions & 0 deletions
@@ -1144,3 +1144,32 @@ def forward(self, x):
 @register_test_case(module_factory=lambda: TensorSplitSections_ListUnpackModule())
 def TensorSplitSections_ListUnpackModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(2, 5))
+
+
+# ==============================================================================
+
+
+class AsStridedWithOffsetModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([2, 6, 60], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        output_size = [6, 20]
+        stride = [60, 1]
+        slice = torch.ops.aten.slice.Tensor(x, 0, 1, 2)
+        squeeze = torch.ops.aten.squeeze.dim(slice, 0)
+        return torch.ops.aten.as_strided(
+            squeeze, size=output_size, stride=stride, storage_offset=360
+        )
+
+
+@register_test_case(module_factory=lambda: AsStridedWithOffsetModule())
+def AsStridedWithOffsetModule_basic(module, tu: TestUtils):
+    module.forward(torch.rand(2, 6, 60))
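The new test pins down the offset-wrapping path: after the slice and squeeze the view holds 6 * 60 = 360 elements, so storage_offset=360 points exactly one full tensor past the view's start. A small sketch of the index arithmetic the TOSA AtenAsStridedOp lowering performs for these shapes (illustrative only, mirroring the C++ loop above):

# Sketch of the target indices computed for AsStridedWithOffsetModule.
output_size = [6, 20]
stride = [60, 1]
offset = 360
self_num_elems = 6 * 60  # elements in the squeezed [6, 60] view

target_indices = []
for i in range(output_size[0] * output_size[1]):
    index, coord_finder = offset, i
    for dim in range(len(output_size) - 1, -1, -1):  # innermost dim first
        index += (coord_finder % output_size[dim]) * stride[dim]
        coord_finder //= output_size[dim]
        index %= self_num_elems  # the new in-loop wrap
    target_indices.append(index)

print(target_indices[:5])  # [0, 1, 2, 3, 4]
print(target_indices[-1])  # 319, i.e. element (5, 19) of the view

Every index wraps back by exactly one multiple of self_num_elems, so output element (r, c) reads element 60*r + c of the squeezed view. That appears to match eager PyTorch, where storage_offset indexes the underlying storage and the squeezed view itself begins at offset 360 of x.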
