
Commit 5eae636

[ONNX] Add support for dynamic dimensions in onnx.LSTM (#4150)
Fixes nod-ai/SHARK-ModelDev#947.

- Implement dynamic dimension support for input tensor `X`
1 parent 1a45fbc commit 5eae636
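The core change replaces compile-time reads of `X`'s shape with `!torch.int` SSA Values, so `?` dimensions in `X` no longer abort the expansion. Below is a minimal sketch of that pattern only, reusing the `getTensorDimSize` / `getTensorTypeFromShapeValues` helpers that appear in the diff; `buildDynamicYTy` is a made-up name for illustration and is not part of this commit.

```cpp
// Sketch only (assumes the headers and helpers already visible in
// OnnxRecurrentLayerOpExpanders.cpp): build a result type from dimension
// Values instead of static int64_t sizes.
static Type buildDynamicYTy(ConversionPatternRewriter &rewriter, Value X,
                            Value cstHiddenSize, Type dtype) {
  // Each dimension of X becomes a !torch.int Value: a constant when the size
  // is statically known, a runtime size query when it is '?'.
  Value seqLen = Torch::getTensorDimSize(rewriter, X, 0);
  Value batchSize = Torch::getTensorDimSize(rewriter, X, 1);
  // The type is assembled from those Values, so dynamic dims propagate as '?'.
  return getTensorTypeFromShapeValues({seqLen, batchSize, cstHiddenSize},
                                      dtype);
}
```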

File tree: 2 files changed (+80 −61 lines)

lib/Conversion/TorchOnnxToTorch/OnnxRecurrentLayerOpExpanders.cpp

Lines changed: 56 additions & 61 deletions
@@ -516,18 +516,17 @@ struct LstmLayerOutput {
 //
 // @return A struct containing the hidden state history, final hidden state,
 // and final cell state.
-LstmLayerOutput lstm_layer(ImplicitLocOpBuilder &b, Value X, Value initial_h,
-                           Value initial_c, LstmWeights weights,
-                           LstmActivations activations) {
+LstmLayerOutput lstm_layer(ConversionPatternRewriter &rewriter, Location &loc,
+                           Value X, Value initial_h, Value initial_c,
+                           LstmWeights weights, LstmActivations activations) {

-  Location loc = b.getLoc();
+  mlir::ImplicitLocOpBuilder b(loc, rewriter);

-  auto xTy = cast<ValueTensorType>(X.getType());
   auto hTy = cast<ValueTensorType>(initial_h.getType());
   // these names are snake_case for consistency with onnx.LSTM documentation
-  int64_t seq_len = xTy.getSizes()[0];
-  int64_t batch_size = xTy.getSizes()[1];
-  int64_t input_size = xTy.getSizes()[2];
+  Value seq_len = getTensorDimSize(rewriter, X, 0);
+  Value batch_size = getTensorDimSize(rewriter, X, 1);
+  Value input_size = getTensorDimSize(rewriter, X, 2);
   int64_t hidden_size = hTy.getSizes()[1];

   auto cTy = hTy;
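For context, `getTensorDimSize` must yield a `!torch.int` Value whether or not the dimension is statically known. The following is only a rough illustration of that contract, not the actual torch-mlir implementation; `getDimAsValue` is a made-up name.

```cpp
// Illustrative stand-in (assumption, not the real torch-mlir helper):
// constant-fold statically known dims, query dynamic ones at runtime.
static Value getDimAsValue(PatternRewriter &rewriter, Location loc,
                           Value tensor, int64_t dim) {
  auto ty = cast<Torch::ValueTensorType>(tensor.getType());
  int64_t size = ty.getSizes()[dim];
  if (size != Torch::kUnknownSize)
    // Statically known: emit a plain torch.constant.int.
    return rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr(size));
  // Dynamic ('?'): ask for the size at runtime via aten.size.int.
  Value dimIdx = rewriter.create<Torch::ConstantIntOp>(
      loc, rewriter.getI64IntegerAttr(dim));
  return rewriter.create<Torch::AtenSizeIntOp>(
      loc, rewriter.getType<Torch::IntType>(), tensor, dimIdx);
}
```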
@@ -537,19 +536,14 @@ LstmLayerOutput lstm_layer(ImplicitLocOpBuilder &b, Value X, Value initial_h,
   Value cstNone = b.create<ConstantNoneOp>();
   Value cstZero = b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(0));
   Value cstOne = b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(1));
-  Value cstSeqLen =
-      b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(seq_len));
-  Value cstBatchSize =
-      b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(batch_size));
   Value cstHiddenSize =
       b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(hidden_size));

-  auto yTy = b.getType<ValueTensorType>(
-      SmallVector<int64_t>{seq_len, batch_size, hidden_size}, hTy.getDtype());
-
+  auto yTy = getTensorTypeFromShapeValues({seq_len, batch_size, cstHiddenSize},
+                                          hTy.getDtype());
   auto YShapeList = b.create<PrimListConstructOp>(
       b.getType<ListType>(intType),
-      ValueRange({cstSeqLen, cstBatchSize, cstHiddenSize}));
+      ValueRange({seq_len, batch_size, cstHiddenSize}));

   int64_t hDtypeInt =
       static_cast<int64_t>(getScalarTypeForType(hTy.getDtype()));
@@ -560,8 +554,7 @@ LstmLayerOutput lstm_layer(ImplicitLocOpBuilder &b, Value X, Value initial_h,
                          cstNone, cstNone, cstNone);

   // Create a for-like PrimLoopOp.
-  Value maxTripCount =
-      b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(seq_len));
+  Value maxTripCount = seq_len;
   Value loopConditionTrue = b.create<ConstantBoolOp>(true);

   Type loopIndexType = intType;
@@ -587,16 +580,16 @@ LstmLayerOutput lstm_layer(ImplicitLocOpBuilder &b, Value X, Value initial_h,
     Value C_prev = loopBody->getArgument(3);

     auto xTy = cast<ValueTensorType>(X.getType());
-    auto XtType = b.getType<ValueTensorType>(
-        llvm::SmallVector<int64_t>{batch_size, input_size}, xTy.getDtype());
+    auto XtType =
+        getTensorTypeFromShapeValues({batch_size, input_size}, xTy.getDtype());

     Value Xt = b.create<AtenSelectIntOp>(XtType, X, cstZero, loopIndex);

     auto [H_new, C_new] =
         lstm_cell(b, Xt, H_prev, C_prev, weights, activations);

-    Type hTyUnsqueezed = b.getType<ValueTensorType>(
-        llvm::SmallVector<int64_t>{1, batch_size, hidden_size}, hTy.getDtype());
+    auto hTyUnsqueezed = getTensorTypeFromShapeValues(
+        {cstOne, batch_size, cstHiddenSize}, hTy.getDtype());
     Value H_new_unsqueezed =
         b.create<AtenUnsqueezeOp>(hTyUnsqueezed, H_new, cstZero);

@@ -773,17 +766,12 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
         binder.op, "invalid value of layout attribute, expecting 0 / 1 got " +
                        std::to_string(layout));

-  auto XShape = xTy.getSizes();
-  int64_t seq_len, batch_size;
-  if (layout == 0) {
-    seq_len = XShape[0];
-    batch_size = XShape[1];
-  } else {
-    seq_len = XShape[1];
-    batch_size = XShape[0];
-  }
+  Value seqLen = getTensorDimSize(rewriter, X, layout == 0 ? 0 : 1);
+  Value batchSize = getTensorDimSize(rewriter, X, layout == 0 ? 1 : 0);

-  int64_t input_size = XShape[2];
+  int64_t x_input_size = xTy.getSizes()[2];
+  int64_t w_input_size = wTy.getSizes()[2];
+  int64_t input_size = w_input_size;
   if (num_directions != wTy.getSizes()[0])
     return rewriter.notifyMatchFailure(
         binder.op, "num_directions (" + std::to_string(num_directions) +
@@ -795,11 +783,22 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
         binder.op, "4 times hidden_size (" + std::to_string(4 * hidden_size) +
                        ") does not match the second dimension of wTy (" +
                        std::to_string(wTy.getSizes()[1]) + ")");
-  if (wTy.getSizes()[2] != input_size)
-    return rewriter.notifyMatchFailure(
-        binder.op,
-        "The third dimension of wTy (" + std::to_string(wTy.getSizes()[2]) +
-            ") does not match input_size (" + std::to_string(input_size) + ")");
+  if (x_input_size != Torch::kUnknownSize) {
+    if (w_input_size != x_input_size)
+      return rewriter.notifyMatchFailure(
+          binder.op, "The input_size of wTy (" + std::to_string(w_input_size) +
+                         ") does not match input_size of xTY (" +
+                         std::to_string(x_input_size) + ")");
+
+  } else {
+    Value x_input_size = Torch::getTensorDimSize(rewriter, X, 2);
+    Value w_input_size =
+        b.create<ConstantIntOp>(loc, b.getI64IntegerAttr(wTy.getSizes()[2]));
+
+    auto eq = b.create<AtenEqIntOp>(loc, x_input_size, w_input_size);
+    rewriter.create<RuntimeAssertOp>(
+        loc, eq, rewriter.getStringAttr("The input_size of W must equal X."));
+  }

   Value W_forward = getDirection(b, 0, W);
   Value R_forward = getDirection(b, 0, R);
@@ -812,25 +811,21 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
     B_reverse = getDirection(b, 1, B);
   }

-  auto hTy = b.getType<ValueTensorType>(
-      llvm::SmallVector<int64_t>{num_directions, batch_size, hidden_size},
-      xTy.getDtype());
-
   auto intType = b.getType<IntType>();

   Value cstNumDirections =
       b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(num_directions));
-  Value cstBatchSize =
-      b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(batch_size));
   Value cstHiddenSize =
       b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(hidden_size));
   Value cstNone = b.create<ConstantNoneOp>();
   Value cstZero = b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(0));
   Value cstOne = b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(1));

+  auto hTy = getTensorTypeFromShapeValues(
+      {cstNumDirections, batchSize, cstHiddenSize}, xTy.getDtype());
   Value hShape = b.create<PrimListConstructOp>(
       b.getType<ListType>(intType),
-      ValueRange({cstNumDirections, cstBatchSize, cstHiddenSize}));
+      ValueRange({cstNumDirections, batchSize, cstHiddenSize}));

   Value cstDtype = getDtypeIntValueForType(rewriter, loc, xTy.getDtype());

@@ -986,26 +981,26 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
   std::tie(weightsRev.R_i, weightsRev.R_o, weightsRev.R_f, weightsRev.R_c) =
       sliceIOFC(sliceGateWeightsHH, R_reverse);

-  LstmLayerOutput lstmLayerOutput = lstm_layer(
-      b, X, initial_h_forward, initial_c_forward, weights, activations);
+  LstmLayerOutput lstmLayerOutput =
+      lstm_layer(rewriter, loc, X, initial_h_forward, initial_c_forward,
+                 weights, activations);

   Value Y_h_result, Y_c_result, Y_result;

   // if forward (unidirectional) unsqueeze and output
   auto YallDtype =
       cast<ValueTensorType>(lstmLayerOutput.Y_h.getType()).getDtype();
-  auto Y_h_Y_c_uni_type = b.getType<ValueTensorType>(
-      llvm::SmallVector<int64_t>{1, batch_size, hidden_size}, YallDtype);
-  auto Y_uni_type = b.getType<ValueTensorType>(
-      llvm::SmallVector<int64_t>{seq_len, 1, batch_size, hidden_size},
-      YallDtype);
-  auto Y_h_Y_c_res_type = b.getType<ValueTensorType>(
-      llvm::SmallVector<int64_t>{num_directions, batch_size, hidden_size},
-      YallDtype);
-  auto Y_res_type = b.getType<ValueTensorType>(
-      llvm::SmallVector<int64_t>{seq_len, num_directions, batch_size,
-                                 hidden_size},
-      YallDtype);
+  auto Y_h_Y_c_uni_type = getTensorTypeFromShapeValues(
+      {cstOne, batchSize, cstHiddenSize}, YallDtype);
+
+  auto Y_uni_type = getTensorTypeFromShapeValues(
+      {seqLen, cstOne, batchSize, cstHiddenSize}, YallDtype);
+
+  auto Y_h_Y_c_res_type = getTensorTypeFromShapeValues(
+      {cstNumDirections, batchSize, cstHiddenSize}, YallDtype);
+
+  auto Y_res_type = getTensorTypeFromShapeValues(
+      {seqLen, cstNumDirections, batchSize, cstHiddenSize}, YallDtype);

   Value Y_h_forward =
       b.create<AtenUnsqueezeOp>(Y_h_Y_c_uni_type, lstmLayerOutput.Y_h, cstZero);
@@ -1034,8 +1029,8 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
                                      SmallVector<Value>{cstZero});
     X_reverse = b.create<AtenFlipOp>(xTy, X, dim0); // flip along seq_len dim
     revLstmLayerOutput =
-        lstm_layer(b, X_reverse, initial_h_reverse, initial_c_reverse,
-                   weightsRev, activationsRev);
+        lstm_layer(rewriter, loc, X_reverse, initial_h_reverse,
+                   initial_c_reverse, weightsRev, activationsRev);

     // unsqueeze Y_rev, Y_h_rev, Y_c_rev
     Y_h_reverse = b.create<AtenUnsqueezeOp>(Y_h_Y_c_uni_type,
@@ -1081,7 +1076,7 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
       outputs;
   ValueTensorType resTy;
   for (int i = 0; i < binder.getNumResults(); ++i) {
-    if (!binder.tensorResultTypeAtIndex(resTy, i) && !resTy) {
+    if (failed(binder.tensorResultTypeAtIndex(resTy, i))) {
       outputs.push_back(cstNone);
     } else {
       outputs.push_back(actualOutputs[i]);

test/Conversion/TorchOnnxToTorch/ops/lstm.mlir

Lines changed: 24 additions & 0 deletions
@@ -84,3 +84,27 @@ func.func @test_lstm_batchwise_two_outputs(%arg0: !torch.vtensor<[3,1,2],f32>, %
   %0:2 = torch.operator "onnx.LSTM"(%arg0, %arg1, %arg2) {torch.onnx.hidden_size = 7 : si64, torch.onnx.layout = 1 : si64} : (!torch.vtensor<[3,1,2],f32>, !torch.vtensor<[1,28,2],f32>, !torch.vtensor<[1,28,7],f32>) -> (!torch.vtensor<[3,1,1,7],f32>, !torch.vtensor<[3,1,7],f32>)
   return %0#0, %0#1 : !torch.vtensor<[3,1,1,7],f32>, !torch.vtensor<[3,1,7],f32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @test_lstm_dynamic(
+// CHECK-SAME: %[[X:.*]]: !torch.vtensor<[?,?,?],f32>,
+// CHECK-SAME: %[[W:.*]]: !torch.vtensor<[1,12,4],f32>,
+// CHECK-SAME: %[[R:.*]]: !torch.vtensor<[1,12,3],f32>,
+// CHECK-SAME: %[[B:.*]]: !torch.vtensor<[1,24],f32>)
+// CHECK: torch.runtime.assert %[[EQ:.*]], "The input_size of W must equal X."
+// CHECK: %[[LOOP_RESULT:.*]]:3 = torch.prim.Loop %[[MAX_TRIPS:.*]], %[[ENTER_LOOP:.*]], init(%[[Y:.*]], %[[INITIAL_H:.*]], %[[INITIAL_C:.*]]) {
+// CHECK: ^bb0(%[[LOOP_INDEX:.*]]: !torch.int, %[[Y_PREV:.*]]: !torch.vtensor<[?,?,3],f32>, %[[H_PREV:.*]]: !torch.vtensor<[?,3],f32>, %[[C_PREV:.*]]: !torch.vtensor<[?,3],f32>):
+// CHECK-DAG: torch.aten.select.int
+// CHECK-DAG: torch.aten.linear
+// CHECK-DAG: torch.aten.sigmoid
+// CHECK-DAG: torch.aten.tanh
+// CHECK-DAG: torch.prim.Loop.condition
+// CHECK-DAG: }
+// CHECK: }
+
+func.func @test_lstm_dynamic(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch.vtensor<[1,12,4],f32>, %arg2: !torch.vtensor<[1,12,3],f32>, %arg3: !torch.vtensor<[1,24],f32>) -> (!torch.vtensor<[?,1,?,3],f32>, !torch.vtensor<[1,?,3],f32>, !torch.vtensor<[1,?,3],f32>) attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64} {
+  %none = torch.constant.none
+  %0:3 = torch.operator "onnx.LSTM"(%arg0, %arg1, %arg2, %arg3) { torch.onnx.hidden_size = 3 : si64 }: (!torch.vtensor<[?,?,?],f32>, !torch.vtensor<[1,12,4],f32>, !torch.vtensor<[1,12,3],f32>, !torch.vtensor<[1,24],f32>)-> (!torch.vtensor<[?,1,?,3],f32>, !torch.vtensor<[1,?,3],f32>, !torch.vtensor<[1,?,3],f32>)
  return %0#0, %0#1, %0#2 : !torch.vtensor<[?,1,?,3],f32>, !torch.vtensor<[1,?,3],f32>, !torch.vtensor<[1,?,3],f32>
+}
