Commit a1769ad
[Torch] Fix dtype mismatch when decomposing AtenNativeLayerNormOp for BF16 types (llvm#4168)
DecomposeAtenNativeLayerNormOp assumed that the intermediate tensors (mean and rstd) have the same dtype as the input tensor, which is incorrect for BF16 inputs: mean and rstd are computed in F32 even when the input is BF16. As a result, the decomposition emitted an invalid AtenExpandAsOp that changed not only the shape but also the dtype. This patch inserts explicit AtenToDtypeOp casts that convert the intermediate mean and rstd tensors back to BF16 before they are broadcast via AtenExpandAsOp.
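For context, a minimal PyTorch sketch (not part of this patch) of the behavior the decomposition has to mirror: with a BF16 input, aten.native_layer_norm returns the normalized output in BF16 but mean and rstd in F32. The shapes and eps below are arbitrary, and the exact mean/rstd dtype can depend on the PyTorch version and backend.

import torch

x = torch.randn(1, 56, 56, 96, dtype=torch.bfloat16)
weight = torch.randn(96, dtype=torch.bfloat16)
bias = torch.randn(96, dtype=torch.bfloat16)

# native_layer_norm returns (output, mean, rstd).
out, mean, rstd = torch.ops.aten.native_layer_norm(x, [96], weight, bias, 1e-5)

# Typically prints: torch.bfloat16 torch.float32 torch.float32
print(out.dtype, mean.dtype, rstd.dtype)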
1 parent 4fe75dd commit a1769ad

File tree

2 files changed: 81 additions & 3 deletions


lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 7 additions & 3 deletions
@@ -7273,9 +7273,11 @@ class DecomposeAtenNativeLayerNormOp
     Value inputMean = rewriter.create<AtenMeanDimOp>(
         loc, reducedTy, op.getInput(), reduceDimList, cstTrue, none);
 
+    Value inputMeanCasted =
+        convertTensorToDtype(rewriter, loc, inputMean, inputTy.getDtype());
     // x - mean(x)
-    Value inputMeanExpanded =
-        rewriter.create<AtenExpandAsOp>(loc, inputTy, inputMean, op.getInput());
+    Value inputMeanExpanded = rewriter.create<AtenExpandAsOp>(
+        loc, inputTy, inputMeanCasted, op.getInput());
     Value inputZeroMean = rewriter.create<AtenSubTensorOp>(
         loc, inputTy, op.getInput(), inputMeanExpanded, one);
     // var(x) = mean((x - mean(x))^2)
@@ -7290,9 +7292,11 @@ class DecomposeAtenNativeLayerNormOp
     Value inputRsqrtVar =
         rewriter.create<AtenRsqrtOp>(loc, reducedTy, inputVarPlusEps);
 
+    Value inputRsqrtVarCasted =
+        convertTensorToDtype(rewriter, loc, inputRsqrtVar, inputTy.getDtype());
     // (x - mean(x)) * rsqrt(var(x) + eps)
     Value inputRsqrtVarExpanded = rewriter.create<AtenExpandAsOp>(
-        loc, inputTy, inputRsqrtVar, op.getInput());
+        loc, inputTy, inputRsqrtVarCasted, op.getInput());
     Value inputNormalized = rewriter.create<AtenMulTensorOp>(
         loc, inputTy, inputZeroMean, inputRsqrtVarExpanded);
     // Convert resultType if dtype is different

test/Dialect/Torch/decompose-complex-ops.mlir

Lines changed: 74 additions & 0 deletions
@@ -772,3 +772,77 @@ func.func @torch.aten.stft.center_2D_hop_length_3_window_pad_both(%arg0: !torch.
   %0 = torch.aten.stft.center %arg0, %nfft, %hoplen, %winlen, %arg1, %cstfalse, %padmode, %cstfalse, %cstfalse, %csttrue, %cstfalse : !torch.vtensor<[3,90],f32>, !torch.int, !torch.int, !torch.int, !torch.vtensor<[8],f32>, !torch.bool, !torch.str, !torch.bool, !torch.bool, !torch.bool, !torch.bool -> !torch.vtensor<[3,6,27],complex<f32>>
   return %0 : !torch.vtensor<[3,6,27],complex<f32>>
 }
+
+
+// -----
+
+
+// CHECK-LABEL: func.func @native_layer_norm(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[1,56,56,96],f32>, %[[ARG1:.*]]: !torch.list<int>, %[[ARG2:.*]]: !torch.vtensor<[96],f32>, %[[ARG3:.*]]: !torch.vtensor<[96],f32>, %[[ARG4:.*]]: !torch.float) -> (!torch.vtensor<[1,56,56,96],f32>, !torch.vtensor<[1,56,56,1],f32>, !torch.vtensor<[1,56,56,1],f32>) {
+// CHECK-DAG: %[[INT96:.*]] = torch.constant.int 96
+// CHECK-DAG: %[[INT56:.*]] = torch.constant.int 56
+// CHECK-DAG: %[[NONE:.*]] = torch.constant.none
+// CHECK-DAG: %[[TRUE:.*]] = torch.constant.bool true
+// CHECK-DAG: %[[INT1:.*]] = torch.constant.int 1
+// CHECK: %[[VAR0:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
+// CHECK: %[[VAR1:.*]] = torch.aten.sum.dim_IntList %[[ARG0]], %[[VAR0]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[1,56,56,96],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,56,56,1],f32>
+// CHECK: %[[VAR2:.*]] = torch.aten.numel %[[ARG0]] : !torch.vtensor<[1,56,56,96],f32> -> !torch.int
+// CHECK: %[[VAR3:.*]] = torch.aten.div.Scalar %[[VAR1]], %[[VAR2]] : !torch.vtensor<[1,56,56,1],f32>, !torch.int -> !torch.vtensor<[1,56,56,1],f32>
+// CHECK: %[[VAR4:.*]] = torch.prim.ListConstruct %int1, %int56, %int56, %int96 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAR5:.*]] = torch.aten.broadcast_to %[[VAR3]], %[[VAR4]] : !torch.vtensor<[1,56,56,1],f32>, !torch.list<int> -> !torch.vtensor<[1,56,56,96],f32>
+// CHECK: %[[VAR6:.*]] = torch.aten.sub.Tensor %[[ARG0]], %[[VAR5]], %int1 : !torch.vtensor<[1,56,56,96],f32>, !torch.vtensor<[1,56,56,96],f32>, !torch.int -> !torch.vtensor<[1,56,56,96],f32>
+// CHECK: %[[VAR7:.*]] = torch.aten.mul.Tensor %[[VAR6]], %[[VAR6]] : !torch.vtensor<[1,56,56,96],f32>, !torch.vtensor<[1,56,56,96],f32> -> !torch.vtensor<[1,56,56,96],f32>
+// CHECK: %[[VAR8:.*]] = torch.aten.sum.dim_IntList %[[VAR7]], %0, %true, %none : !torch.vtensor<[1,56,56,96],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,56,56,1],f32>
+// CHECK: %[[VAR9:.*]] = torch.aten.numel %7 : !torch.vtensor<[1,56,56,96],f32> -> !torch.int
+// CHECK: %[[VAR10:.*]] = torch.aten.div.Scalar %[[VAR8]], %[[VAR9]] : !torch.vtensor<[1,56,56,1],f32>, !torch.int -> !torch.vtensor<[1,56,56,1],f32>
+// CHECK: %[[VAR11:.*]] = torch.aten.add.Scalar %[[VAR10]], %[[ARG4]], %int1 : !torch.vtensor<[1,56,56,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,56,56,1],f32>
+// CHECK: %[[VAR12:.*]] = torch.aten.rsqrt %[[VAR11]] : !torch.vtensor<[1,56,56,1],f32> -> !torch.vtensor<[1,56,56,1],f32>
+// CHECK: %[[VAR13:.*]] = torch.prim.ListConstruct %int1, %int56, %int56, %int96 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAR14:.*]] = torch.aten.broadcast_to %[[VAR12]], %[[VAR13]] : !torch.vtensor<[1,56,56,1],f32>, !torch.list<int> -> !torch.vtensor<[1,56,56,96],f32>
+// CHECK: %[[VAR15:.*]] = torch.aten.mul.Tensor %[[VAR6]], %[[VAR14]] : !torch.vtensor<[1,56,56,96],f32>, !torch.vtensor<[1,56,56,96],f32> -> !torch.vtensor<[1,56,56,96],f32>
+// CHECK: %[[VAR16:.*]] = torch.aten.mul.Tensor %[[VAR15]], %[[ARG2]] : !torch.vtensor<[1,56,56,96],f32>, !torch.vtensor<[96],f32> -> !torch.vtensor<[1,56,56,96],f32>
+// CHECK: %[[VAR17:.*]] = torch.aten.add.Tensor %[[VAR16]], %[[ARG3]], %int1 : !torch.vtensor<[1,56,56,96],f32>, !torch.vtensor<[96],f32>, !torch.int -> !torch.vtensor<[1,56,56,96],f32>
+// CHECK: return %[[VAR17]], %[[VAR3]], %[[VAR12]] : !torch.vtensor<[1,56,56,96],f32>, !torch.vtensor<[1,56,56,1],f32>, !torch.vtensor<[1,56,56,1],f32>
+func.func @native_layer_norm(%input: !torch.vtensor<[1,56,56,96],f32>, %normalized_shape: !torch.list<int>, %weight: !torch.vtensor<[96],f32>, %bias: !torch.vtensor<[96],f32>, %eps: !torch.float) -> (!torch.vtensor<[1,56,56,96],f32>, !torch.vtensor<[1,56,56,1],f32>, !torch.vtensor<[1,56,56,1],f32>) {
+  %result, %mean, %rstd = torch.aten.native_layer_norm %input, %normalized_shape, %weight, %bias, %eps : !torch.vtensor<[1,56,56,96],f32>, !torch.list<int>, !torch.vtensor<[96],f32>, !torch.vtensor<[96],f32>, !torch.float -> !torch.vtensor<[1,56,56,96],f32>, !torch.vtensor<[1,56,56,1],f32>, !torch.vtensor<[1,56,56,1],f32>
+  return %result, %mean, %rstd : !torch.vtensor<[1,56,56,96],f32>, !torch.vtensor<[1,56,56,1],f32>, !torch.vtensor<[1,56,56,1],f32>
+}
+
+
+// -----
+
+
+// CHECK-LABEL: func.func @native_layer_norm_mixed_dtypes(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[1,56,56,96],bf16>, %[[ARG1:.*]]: !torch.list<int>, %[[ARG2:.*]]: !torch.vtensor<[96],bf16>, %[[ARG3:.*]]: !torch.vtensor<[96],bf16>, %[[ARG4:.*]]: !torch.float) -> (!torch.vtensor<[1,56,56,96],bf16>, !torch.vtensor<[1,56,56,1],f32>, !torch.vtensor<[1,56,56,1],f32>) {
+// CHECK-DAG: %[[INT96:.*]] = torch.constant.int 96
+// CHECK-DAG: %[[INT56:.*]] = torch.constant.int 56
+// CHECK-DAG: %[[INT15:.*]] = torch.constant.int 15
+// CHECK-DAG: %[[FALSE:.*]] = torch.constant.bool false
+// CHECK-DAG: %[[NONE:.*]] = torch.constant.none
+// CHECK-DAG: %[[TRUE:.*]] = torch.constant.bool true
+// CHECK-DAG: %[[INT1:.*]] = torch.constant.int 1
+// CHECK: %[[VAR0:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
+// CHECK: %[[VAR1:.*]] = torch.aten.sum.dim_IntList %[[ARG0]], %[[VAR0]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[1,56,56,96],bf16>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,56,56,1],f32>
+// CHECK: %[[VAR2:.*]] = torch.aten.numel %[[ARG0]] : !torch.vtensor<[1,56,56,96],bf16> -> !torch.int
+// CHECK: %[[VAR3:.*]] = torch.aten.div.Scalar %[[VAR1]], %[[VAR2]] : !torch.vtensor<[1,56,56,1],f32>, !torch.int -> !torch.vtensor<[1,56,56,1],f32>
+// CHECK: %[[VAR4:.*]] = torch.aten.to.dtype %[[VAR3]], %[[INT15]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[1,56,56,1],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,56,56,1],bf16>
+// CHECK: %[[VAR5:.*]] = torch.prim.ListConstruct %int1, %int56, %int56, %int96 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAR6:.*]] = torch.aten.broadcast_to %[[VAR4]], %[[VAR5]] : !torch.vtensor<[1,56,56,1],bf16>, !torch.list<int> -> !torch.vtensor<[1,56,56,96],bf16>
+// CHECK: %[[VAR7:.*]] = torch.aten.sub.Tensor %[[ARG0]], %[[VAR6]], %int1 : !torch.vtensor<[1,56,56,96],bf16>, !torch.vtensor<[1,56,56,96],bf16>, !torch.int -> !torch.vtensor<[1,56,56,96],bf16>
+// CHECK: %[[VAR8:.*]] = torch.aten.mul.Tensor %[[VAR7]], %[[VAR7]] : !torch.vtensor<[1,56,56,96],bf16>, !torch.vtensor<[1,56,56,96],bf16> -> !torch.vtensor<[1,56,56,96],bf16>
+// CHECK: %[[VAR9:.*]] = torch.aten.sum.dim_IntList %[[VAR8]], %0, %true, %none : !torch.vtensor<[1,56,56,96],bf16>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,56,56,1],f32>
+// CHECK: %[[VAR10:.*]] = torch.aten.numel %8 : !torch.vtensor<[1,56,56,96],bf16> -> !torch.int
+// CHECK: %[[VAR11:.*]] = torch.aten.div.Scalar %[[VAR9]], %[[VAR10]] : !torch.vtensor<[1,56,56,1],f32>, !torch.int -> !torch.vtensor<[1,56,56,1],f32>
+// CHECK: %[[VAR12:.*]] = torch.aten.add.Scalar %[[VAR11]], %[[ARG4]], %int1 : !torch.vtensor<[1,56,56,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,56,56,1],f32>
+// CHECK: %[[VAR13:.*]] = torch.aten.rsqrt %[[VAR12]] : !torch.vtensor<[1,56,56,1],f32> -> !torch.vtensor<[1,56,56,1],f32>
+// CHECK: %[[VAR14:.*]] = torch.aten.to.dtype %[[VAR13]], %[[INT15]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[1,56,56,1],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,56,56,1],bf16>
+// CHECK: %[[VAR15:.*]] = torch.prim.ListConstruct %int1, %int56, %int56, %int96 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAR16:.*]] = torch.aten.broadcast_to %[[VAR14]], %[[VAR15]] : !torch.vtensor<[1,56,56,1],bf16>, !torch.list<int> -> !torch.vtensor<[1,56,56,96],bf16>
+// CHECK: %[[VAR17:.*]] = torch.aten.mul.Tensor %[[VAR7]], %[[VAR16]] : !torch.vtensor<[1,56,56,96],bf16>, !torch.vtensor<[1,56,56,96],bf16> -> !torch.vtensor<[1,56,56,96],bf16>
+// CHECK: %[[VAR18:.*]] = torch.aten.mul.Tensor %[[VAR17]], %[[ARG2]] : !torch.vtensor<[1,56,56,96],bf16>, !torch.vtensor<[96],bf16> -> !torch.vtensor<[1,56,56,96],bf16>
+// CHECK: %[[VAR19:.*]] = torch.aten.add.Tensor %[[VAR18]], %[[ARG3]], %int1 : !torch.vtensor<[1,56,56,96],bf16>, !torch.vtensor<[96],bf16>, !torch.int -> !torch.vtensor<[1,56,56,96],bf16>
+// CHECK: return %[[VAR19]], %[[VAR3]], %[[VAR13]] : !torch.vtensor<[1,56,56,96],bf16>, !torch.vtensor<[1,56,56,1],f32>, !torch.vtensor<[1,56,56,1],f32>
+func.func @native_layer_norm_mixed_dtypes(%input: !torch.vtensor<[1,56,56,96],bf16>, %normalized_shape: !torch.list<int>, %weight: !torch.vtensor<[96],bf16>, %bias: !torch.vtensor<[96],bf16>, %eps: !torch.float) -> (!torch.vtensor<[1,56,56,96],bf16>, !torch.vtensor<[1,56,56,1],f32>, !torch.vtensor<[1,56,56,1],f32>) {
+  %result, %mean, %rstd = torch.aten.native_layer_norm %input, %normalized_shape, %weight, %bias, %eps : !torch.vtensor<[1,56,56,96],bf16>, !torch.list<int>, !torch.vtensor<[96],bf16>, !torch.vtensor<[96],bf16>, !torch.float -> !torch.vtensor<[1,56,56,96],bf16>, !torch.vtensor<[1,56,56,1],f32>, !torch.vtensor<[1,56,56,1],f32>
+  return %result, %mean, %rstd : !torch.vtensor<[1,56,56,96],bf16>, !torch.vtensor<[1,56,56,1],f32>, !torch.vtensor<[1,56,56,1],f32>
+}
