@@ -516,18 +516,17 @@ struct LstmLayerOutput {
516
516
//
517
517
// @return A struct containing the hidden state history, final hidden state,
518
518
// and final cell state.
519
- LstmLayerOutput lstm_layer (ImplicitLocOpBuilder &b, Value X, Value initial_h ,
520
- Value initial_c, LstmWeights weights ,
521
- LstmActivations activations) {
519
+ LstmLayerOutput lstm_layer (ConversionPatternRewriter &rewriter, Location &loc ,
520
+ Value X, Value initial_h, Value initial_c ,
521
+ LstmWeights weights, LstmActivations activations) {
522
522
523
- Location loc = b. getLoc ( );
523
+ mlir::ImplicitLocOpBuilder b (loc, rewriter );
524
524
525
- auto xTy = cast<ValueTensorType>(X.getType ());
526
525
auto hTy = cast<ValueTensorType>(initial_h.getType ());
527
526
// these names are snake_case for consistency with onnx.LSTM documentation
528
- int64_t seq_len = xTy. getSizes ()[ 0 ] ;
529
- int64_t batch_size = xTy. getSizes ()[ 1 ] ;
530
- int64_t input_size = xTy. getSizes ()[ 2 ] ;
527
+ Value seq_len = getTensorDimSize (rewriter, X, 0 ) ;
528
+ Value batch_size = getTensorDimSize (rewriter, X, 1 ) ;
529
+ Value input_size = getTensorDimSize (rewriter, X, 2 ) ;
531
530
int64_t hidden_size = hTy.getSizes ()[1 ];
532
531
533
532
auto cTy = hTy;
@@ -537,19 +536,14 @@ LstmLayerOutput lstm_layer(ImplicitLocOpBuilder &b, Value X, Value initial_h,
537
536
Value cstNone = b.create <ConstantNoneOp>();
538
537
Value cstZero = b.create <ConstantIntOp>(intType, b.getI64IntegerAttr (0 ));
539
538
Value cstOne = b.create <ConstantIntOp>(intType, b.getI64IntegerAttr (1 ));
540
- Value cstSeqLen =
541
- b.create <ConstantIntOp>(intType, b.getI64IntegerAttr (seq_len));
542
- Value cstBatchSize =
543
- b.create <ConstantIntOp>(intType, b.getI64IntegerAttr (batch_size));
544
539
Value cstHiddenSize =
545
540
b.create <ConstantIntOp>(intType, b.getI64IntegerAttr (hidden_size));
546
541
547
- auto yTy = b.getType <ValueTensorType>(
548
- SmallVector<int64_t >{seq_len, batch_size, hidden_size}, hTy.getDtype ());
549
-
542
+ auto yTy = getTensorTypeFromShapeValues ({seq_len, batch_size, cstHiddenSize},
543
+ hTy.getDtype ());
550
544
auto YShapeList = b.create <PrimListConstructOp>(
551
545
b.getType <ListType>(intType),
552
- ValueRange ({cstSeqLen, cstBatchSize , cstHiddenSize}));
546
+ ValueRange ({seq_len, batch_size , cstHiddenSize}));
553
547
554
548
int64_t hDtypeInt =
555
549
static_cast <int64_t >(getScalarTypeForType (hTy.getDtype ()));
@@ -560,8 +554,7 @@ LstmLayerOutput lstm_layer(ImplicitLocOpBuilder &b, Value X, Value initial_h,
560
554
cstNone, cstNone, cstNone);
561
555
562
556
// Create a for-like PrimLoopOp.
563
- Value maxTripCount =
564
- b.create <ConstantIntOp>(intType, b.getI64IntegerAttr (seq_len));
557
+ Value maxTripCount = seq_len;
565
558
Value loopConditionTrue = b.create <ConstantBoolOp>(true );
566
559
567
560
Type loopIndexType = intType;
@@ -587,16 +580,16 @@ LstmLayerOutput lstm_layer(ImplicitLocOpBuilder &b, Value X, Value initial_h,
587
580
Value C_prev = loopBody->getArgument (3 );
588
581
589
582
auto xTy = cast<ValueTensorType>(X.getType ());
590
- auto XtType = b. getType <ValueTensorType>(
591
- llvm::SmallVector< int64_t > {batch_size, input_size}, xTy.getDtype ());
583
+ auto XtType =
584
+ getTensorTypeFromShapeValues ( {batch_size, input_size}, xTy.getDtype ());
592
585
593
586
Value Xt = b.create <AtenSelectIntOp>(XtType, X, cstZero, loopIndex);
594
587
595
588
auto [H_new, C_new] =
596
589
lstm_cell (b, Xt, H_prev, C_prev, weights, activations);
597
590
598
- Type hTyUnsqueezed = b. getType <ValueTensorType> (
599
- llvm::SmallVector< int64_t >{ 1 , batch_size, hidden_size }, hTy.getDtype ());
591
+ auto hTyUnsqueezed = getTensorTypeFromShapeValues (
592
+ {cstOne , batch_size, cstHiddenSize }, hTy.getDtype ());
600
593
Value H_new_unsqueezed =
601
594
b.create <AtenUnsqueezeOp>(hTyUnsqueezed, H_new, cstZero);
602
595
@@ -773,17 +766,12 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
773
766
binder.op , " invalid value of layout attribute, expecting 0 / 1 got " +
774
767
std::to_string (layout));
775
768
776
- auto XShape = xTy.getSizes ();
777
- int64_t seq_len, batch_size;
778
- if (layout == 0 ) {
779
- seq_len = XShape[0 ];
780
- batch_size = XShape[1 ];
781
- } else {
782
- seq_len = XShape[1 ];
783
- batch_size = XShape[0 ];
784
- }
769
+ Value seqLen = getTensorDimSize (rewriter, X, layout == 0 ? 0 : 1 );
770
+ Value batchSize = getTensorDimSize (rewriter, X, layout == 0 ? 1 : 0 );
785
771
786
- int64_t input_size = XShape[2 ];
772
+ int64_t x_input_size = xTy.getSizes ()[2 ];
773
+ int64_t w_input_size = wTy.getSizes ()[2 ];
774
+ int64_t input_size = w_input_size;
787
775
if (num_directions != wTy.getSizes ()[0 ])
788
776
return rewriter.notifyMatchFailure (
789
777
binder.op , " num_directions (" + std::to_string (num_directions) +
@@ -795,11 +783,22 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
795
783
binder.op , " 4 times hidden_size (" + std::to_string (4 * hidden_size) +
796
784
" ) does not match the second dimension of wTy (" +
797
785
std::to_string (wTy.getSizes ()[1 ]) + " )" );
798
- if (wTy.getSizes ()[2 ] != input_size)
799
- return rewriter.notifyMatchFailure (
800
- binder.op ,
801
- " The third dimension of wTy (" + std::to_string (wTy.getSizes ()[2 ]) +
802
- " ) does not match input_size (" + std::to_string (input_size) + " )" );
786
+ if (x_input_size != Torch::kUnknownSize ) {
787
+ if (w_input_size != x_input_size)
788
+ return rewriter.notifyMatchFailure (
789
+ binder.op , " The input_size of wTy (" + std::to_string (w_input_size) +
790
+ " ) does not match input_size of xTY (" +
791
+ std::to_string (x_input_size) + " )" );
792
+
793
+ } else {
794
+ Value x_input_size = Torch::getTensorDimSize (rewriter, X, 2 );
795
+ Value w_input_size =
796
+ b.create <ConstantIntOp>(loc, b.getI64IntegerAttr (wTy.getSizes ()[2 ]));
797
+
798
+ auto eq = b.create <AtenEqIntOp>(loc, x_input_size, w_input_size);
799
+ rewriter.create <RuntimeAssertOp>(
800
+ loc, eq, rewriter.getStringAttr (" The input_size of W must equal X." ));
801
+ }
803
802
804
803
Value W_forward = getDirection (b, 0 , W);
805
804
Value R_forward = getDirection (b, 0 , R);
@@ -812,25 +811,21 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
812
811
B_reverse = getDirection (b, 1 , B);
813
812
}
814
813
815
- auto hTy = b.getType <ValueTensorType>(
816
- llvm::SmallVector<int64_t >{num_directions, batch_size, hidden_size},
817
- xTy.getDtype ());
818
-
819
814
auto intType = b.getType <IntType>();
820
815
821
816
Value cstNumDirections =
822
817
b.create <ConstantIntOp>(intType, b.getI64IntegerAttr (num_directions));
823
- Value cstBatchSize =
824
- b.create <ConstantIntOp>(intType, b.getI64IntegerAttr (batch_size));
825
818
Value cstHiddenSize =
826
819
b.create <ConstantIntOp>(intType, b.getI64IntegerAttr (hidden_size));
827
820
Value cstNone = b.create <ConstantNoneOp>();
828
821
Value cstZero = b.create <ConstantIntOp>(intType, b.getI64IntegerAttr (0 ));
829
822
Value cstOne = b.create <ConstantIntOp>(intType, b.getI64IntegerAttr (1 ));
830
823
824
+ auto hTy = getTensorTypeFromShapeValues (
825
+ {cstNumDirections, batchSize, cstHiddenSize}, xTy.getDtype ());
831
826
Value hShape = b.create <PrimListConstructOp>(
832
827
b.getType <ListType>(intType),
833
- ValueRange ({cstNumDirections, cstBatchSize , cstHiddenSize}));
828
+ ValueRange ({cstNumDirections, batchSize , cstHiddenSize}));
834
829
835
830
Value cstDtype = getDtypeIntValueForType (rewriter, loc, xTy.getDtype ());
836
831
@@ -986,26 +981,26 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
986
981
std::tie (weightsRev.R_i , weightsRev.R_o , weightsRev.R_f , weightsRev.R_c ) =
987
982
sliceIOFC (sliceGateWeightsHH, R_reverse);
988
983
989
- LstmLayerOutput lstmLayerOutput = lstm_layer (
990
- b, X, initial_h_forward, initial_c_forward, weights, activations);
984
+ LstmLayerOutput lstmLayerOutput =
985
+ lstm_layer (rewriter, loc, X, initial_h_forward, initial_c_forward,
986
+ weights, activations);
991
987
992
988
Value Y_h_result, Y_c_result, Y_result;
993
989
994
990
// if forward (unidirectional) unsqueeze and output
995
991
auto YallDtype =
996
992
cast<ValueTensorType>(lstmLayerOutput.Y_h .getType ()).getDtype ();
997
- auto Y_h_Y_c_uni_type = b.getType <ValueTensorType>(
998
- llvm::SmallVector<int64_t >{1 , batch_size, hidden_size}, YallDtype);
999
- auto Y_uni_type = b.getType <ValueTensorType>(
1000
- llvm::SmallVector<int64_t >{seq_len, 1 , batch_size, hidden_size},
1001
- YallDtype);
1002
- auto Y_h_Y_c_res_type = b.getType <ValueTensorType>(
1003
- llvm::SmallVector<int64_t >{num_directions, batch_size, hidden_size},
1004
- YallDtype);
1005
- auto Y_res_type = b.getType <ValueTensorType>(
1006
- llvm::SmallVector<int64_t >{seq_len, num_directions, batch_size,
1007
- hidden_size},
1008
- YallDtype);
993
+ auto Y_h_Y_c_uni_type = getTensorTypeFromShapeValues (
994
+ {cstOne, batchSize, cstHiddenSize}, YallDtype);
995
+
996
+ auto Y_uni_type = getTensorTypeFromShapeValues (
997
+ {seqLen, cstOne, batchSize, cstHiddenSize}, YallDtype);
998
+
999
+ auto Y_h_Y_c_res_type = getTensorTypeFromShapeValues (
1000
+ {cstNumDirections, batchSize, cstHiddenSize}, YallDtype);
1001
+
1002
+ auto Y_res_type = getTensorTypeFromShapeValues (
1003
+ {seqLen, cstNumDirections, batchSize, cstHiddenSize}, YallDtype);
1009
1004
1010
1005
Value Y_h_forward =
1011
1006
b.create <AtenUnsqueezeOp>(Y_h_Y_c_uni_type, lstmLayerOutput.Y_h , cstZero);
@@ -1034,8 +1029,8 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
1034
1029
SmallVector<Value>{cstZero});
1035
1030
X_reverse = b.create <AtenFlipOp>(xTy, X, dim0); // flip along seq_len dim
1036
1031
revLstmLayerOutput =
1037
- lstm_layer (b, X_reverse, initial_h_reverse, initial_c_reverse ,
1038
- weightsRev, activationsRev);
1032
+ lstm_layer (rewriter, loc, X_reverse, initial_h_reverse ,
1033
+ initial_c_reverse, weightsRev, activationsRev);
1039
1034
1040
1035
// unsqueeze Y_rev, Y_h_rev, Y_c_rev
1041
1036
Y_h_reverse = b.create <AtenUnsqueezeOp>(Y_h_Y_c_uni_type,
@@ -1081,7 +1076,7 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
1081
1076
outputs;
1082
1077
ValueTensorType resTy;
1083
1078
for (int i = 0 ; i < binder.getNumResults (); ++i) {
1084
- if (! binder.tensorResultTypeAtIndex (resTy, i) && !resTy ) {
1079
+ if (failed ( binder.tensorResultTypeAtIndex (resTy, i)) ) {
1085
1080
outputs.push_back (cstNone);
1086
1081
} else {
1087
1082
outputs.push_back (actualOutputs[i]);
0 commit comments