Skip to content

Commit 4042304

Browse files
[TORCH] Add f8 support in getConstantWithGivenDtypeAndValue utility (#4148)
Fixes iree-org/iree#20570. --------- Signed-off-by: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
1 parent c632c86 commit 4042304

File tree

2 files changed

+13
-1
lines changed

2 files changed

+13
-1
lines changed

lib/Dialect/Torch/Utils/Utils.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,9 @@ Value Torch::getConstantWithGivenDtypeAndValue(PatternRewriter &rewriter,
338338
dtype.isInteger(8) || dtype.isInteger(1))
339339
return rewriter.create<ConstantIntOp>(
340340
loc, rewriter.getI64IntegerAttr((int64_t)value));
341-
if (dtype.isF64() || dtype.isF32() || dtype.isF16() || dtype.isBF16())
341+
if (dtype.isF64() || dtype.isF32() || dtype.isF16() || dtype.isBF16() ||
342+
isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
343+
Float8E4M3FNUZType>(dtype))
342344
return rewriter.create<ConstantFloatOp>(loc,
343345
rewriter.getF64FloatAttr(value));
344346
llvm::report_fatal_error(

test/Dialect/Torch/torch-function-to-torch-backend-pipeline.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,13 @@ func.func @torch.uint8(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[
2525
%1 = torch.aten.reshape %arg0, %0 : !torch.tensor, !torch.list<int> -> !torch.tensor
2626
return %1 : !torch.tensor
2727
}
28+
29+
// CHECK-LABEL: func.func @torch.f8type
30+
func.func @torch.f8type(%arg0: !torch.vtensor<[5,3],f8E4M3FNUZ>) -> !torch.vtensor<[5,3],f8E4M3FNUZ> {
31+
// CHECK: torch.aten.exp
32+
// CHECK: torch.aten.log1p
33+
// CHECK: torch.aten.tanh
34+
// CHECK-SAME: !torch.vtensor<[5,3],f8E4M3FNUZ>
35+
%0 = torch.aten.mish %arg0 : !torch.vtensor<[5,3],f8E4M3FNUZ> -> !torch.vtensor<[5,3],f8E4M3FNUZ>
36+
return %0 : !torch.vtensor<[5,3],f8E4M3FNUZ>
37+
}

0 commit comments

Comments
 (0)