Commit 8cc313f

[TORCH] Add Kullback-Leibler divergence loss support (llvm#4204)
This PR takes care of llvm#4203.

- Adds e2e support for the **aten.kl_div** op, covering all reduction modes (`mean`, `sum`, `batchmean`, `none`).
- `reduction: batchmean` requires special handling: the op is lowered with `sum` and the result is then divided by the input's `batch_size`.

Some tests fail and are marked in the expected-failure or crashing sets:

| Config | Error | Marked in |
| --- | --- | --- |
| `linalg` | **RuntimeError**: attribute lookup is not defined on builtin | `LINALG_XFAIL_SET` |
| `torchdynamo` | **Error**: failed to legalize operation `torch.aten.xlogy.Tensor` | `TORCHDYNAMO_CRASHING_SET` |
| `onnx` | **RuntimeError**: `aten::div()` expected a value of type 'number' for argument 'other' but instead found type 'Tensor' (cannot cast `tensor(1)` to number; declaration: `aten::div.Scalar(Tensor self, Scalar other) -> Tensor`) | `ONNX_XFAIL_SET` |
| `onnx_tosa` | **Error**: failed to legalize operation `torch.aten.size.int` that was explicitly marked illegal | `ONNX_TOSA_XFAIL_SET` |

Signed-off-by: Zahid Wakeel <zahid.wakeel@multicorewareinc.com>
1 parent 867eb39 commit 8cc313f
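For context, a minimal PyTorch sketch (not part of this commit) of the reduction semantics the new op has to match, including the `batchmean` handling described above; the tensors and shapes are illustrative:

```python
import torch
import torch.nn.functional as F

# F.kl_div expects the input in log space and the target as probabilities
# (when log_target=False).
inp = F.log_softmax(torch.randn(4, 5), dim=-1)
tgt = F.softmax(torch.randn(4, 5), dim=-1)

pointwise = F.kl_div(inp, tgt, reduction="none")   # shape (4, 5)
mean = F.kl_div(inp, tgt, reduction="mean")        # pointwise.mean()
total = F.kl_div(inp, tgt, reduction="sum")        # pointwise.sum()

# batchmean is sum divided by the batch size, which is exactly how this
# PR lowers it (call with `sum`, then divide by batch_size).
batchmean = F.kl_div(inp, tgt, reduction="batchmean")
assert torch.allclose(batchmean, total / inp.shape[0])
```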

File tree

9 files changed: +339 −0 lines changed

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 26 additions & 0 deletions
```diff
@@ -9506,6 +9506,32 @@ def Torch_AtenPoissonNllLossOp : Torch_Op<"aten.poisson_nll_loss", [
   }];
 }
 
+def Torch_AtenKlDivOp : Torch_Op<"aten.kl_div", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::kl_div : (Tensor, Tensor, int, bool) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$target,
+    Torch_IntType:$reduction,
+    Torch_BoolType:$log_target
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenKlDivOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 4, 1);
+    }
+    void AtenKlDivOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 4, 1);
+    }
+  }];
+}
+
 def Torch_AtenBincountOp : Torch_Op<"aten.bincount", [
   AllowsTypeRefinement,
   HasValueSemantics,
```
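The generated op keeps the raw `aten::kl_div` signature, in which `reduction` is an integer code rather than a string. A quick illustrative call through the raw op (assuming PyTorch's usual `Reduction` enum values: 0 = none, 1 = mean, 2 = sum):

```python
import torch

x = torch.log_softmax(torch.randn(2, 3), dim=-1)
y = torch.softmax(torch.randn(2, 3), dim=-1)

# aten::kl_div : (Tensor, Tensor, int, bool) -> (Tensor)
out_none = torch.ops.aten.kl_div(x, y, 0, False)  # elementwise, shape (2, 3)
out_mean = torch.ops.aten.kl_div(x, y, 1, False)  # rank-0 tensor
out_sum = torch.ops.aten.kl_div(x, y, 2, False)   # rank-0 tensor
```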

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 33 additions & 0 deletions
```diff
@@ -10692,6 +10692,31 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.kl_div\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.int, %arg3: !torch.bool) -> !torch.list<int> {\n"
+"    %none = torch.constant.none\n"
+"    %str = torch.constant.str \"AssertionError: Invalid reduction value.\"\n"
+"    %int0 = torch.constant.int 0\n"
+"    %int1 = torch.constant.int 1\n"
+"    %int2 = torch.constant.int 2\n"
+"    %0 = torch.prim.Uninitialized : !torch.list<int>\n"
+"    %1 = torch.aten.eq.int %arg2, %int0 : !torch.int, !torch.int -> !torch.bool\n"
+"    %2 = torch.prim.If %1 -> (!torch.list<int>) {\n"
+"      %3 = func.call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+"      torch.prim.If.yield %3 : !torch.list<int>\n"
+"    } else {\n"
+"      %3 = torch.prim.ListConstruct %int1, %int2 : (!torch.int, !torch.int) -> !torch.list<int>\n"
+"      %4 = torch.aten.__contains__.int_list %3, %arg2 : !torch.list<int>, !torch.int -> !torch.bool\n"
+"      %5 = torch.prim.If %4 -> (!torch.list<int>) {\n"
+"        %6 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
+"        torch.prim.If.yield %6 : !torch.list<int>\n"
+"      } else {\n"
+"        torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"        torch.prim.If.yield %0 : !torch.list<int>\n"
+"      }\n"
+"      torch.prim.If.yield %5 : !torch.list<int>\n"
+"    }\n"
+"    return %2 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.nll_loss_forward\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.int, %arg4: !torch.int) -> !torch.tuple<list<int>, list<int>> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.nll_loss_forward(%arg0, %arg1, %arg2, %arg3) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.int) -> !torch.tuple<list<int>, list<int>>\n"
 "    return %0 : !torch.tuple<list<int>, list<int>>\n"
@@ -14575,6 +14600,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    }\n"
 "    return %int3 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.kl_div\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.int, %arg3: !torch.bool) -> !torch.int {\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
+"    %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
+"    %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
+"    return %4 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.mse_loss\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.int) -> !torch.int {\n"
 "    %none = torch.constant.none\n"
 "    %str = torch.constant.str \"AssertionError: \"\n"
```

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 79 additions & 0 deletions
```diff
@@ -10629,6 +10629,84 @@ class DecomposeAtenPoissonNllLossOp
 };
 } // namespace
 
+namespace {
+class DecomposeAtenKlDivOp : public OpRewritePattern<AtenKlDivOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenKlDivOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value self = op.getSelf();
+    Value target = op.getTarget();
+    Value reductionValue = op.getReduction();
+    Value logTargetValue = op.getLogTarget();
+
+    auto selfTy = cast<ValueTensorType>(self.getType());
+    auto targetTy = cast<ValueTensorType>(target.getType());
+    auto outTy = cast<ValueTensorType>(op.getType());
+
+    if (!selfTy.hasSizes() || !targetTy.hasSizes() || !outTy.hasSizes()) {
+      return rewriter.notifyMatchFailure(
+          op, "require self, target and output having sizes!");
+    }
+
+    if (!selfTy.hasDtype() || !targetTy.hasDtype() || !outTy.hasDtype()) {
+      return rewriter.notifyMatchFailure(
+          op, "require self, target and output having dtype!");
+    }
+
+    // Extract boolean value from logTarget argument
+    bool logTargetBool;
+    if (!matchPattern(logTargetValue, m_TorchConstantBool(&logTargetBool)))
+      return rewriter.notifyMatchFailure(
+          op, "Expected a constant boolean value for logTargetBool");
+
+    // Default: target tensor is not in log space
+    Value logOfTarget;
+    if (!logTargetBool) {
+      logOfTarget = rewriter.create<AtenLogOp>(loc, targetTy, target);
+    } else {
+      logOfTarget = target;
+    }
+
+    Value constOne =
+        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
+    Value subValue = rewriter.create<AtenSubTensorOp>(loc, selfTy, logOfTarget,
+                                                      self, constOne);
+
+    // if target tensor is already in log space
+    if (logTargetBool) {
+      target = rewriter.create<AtenExpOp>(loc, targetTy, target);
+    }
+    Value lossPointwise =
+        rewriter.create<AtenMulTensorOp>(loc, targetTy, target, subValue);
+
+    // Extract reduction int value from reduction argument
+    int64_t reduction;
+    if (!matchPattern(reductionValue, m_TorchConstantInt(&reduction))) {
+      return rewriter.notifyMatchFailure(op,
+                                         "reduction should be a constant int!");
+    }
+
+    Value loss;
+    Value none = rewriter.create<ConstantNoneOp>(loc);
+    // reduction: mean
+    if (reduction == 1) {
+      loss = rewriter.create<AtenMeanOp>(loc, outTy, lossPointwise, none);
+    } else if (reduction == 2) {
+      // reduction: sum
+      loss = rewriter.create<AtenSumOp>(loc, outTy, lossPointwise, none);
+    } else {
+      // reduction: none
+      loss = lossPointwise;
+    }
+
+    rewriter.replaceOp(op, loss);
+
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposeAtenBinaryCrossEntropyWithLogitsOp
     : public OpRewritePattern<AtenBinaryCrossEntropyWithLogitsOp> {
@@ -12546,6 +12624,7 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal<DecomposeAtenPoissonNllLossOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenBinaryCrossEntropyWithLogitsOp>(
         patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenKlDivOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanDimOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenTopkOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenArgsortOp>(patterns);
```
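A Python sketch (illustrative, not from the patch) of what the decomposition above computes, the same `log`/`exp`, `sub`, `mul`, and reduction sequence, which can be cross-checked against `torch.nn.functional.kl_div`:

```python
import torch
import torch.nn.functional as F

def kl_div_decomposed(self, target, reduction=1, log_target=False):
    # log(target) unless the target is already in log space.
    log_of_target = target if log_target else torch.log(target)
    sub = log_of_target - self
    # When the target is in log space, multiply by exp(target) instead.
    mul_target = torch.exp(target) if log_target else target
    pointwise = mul_target * sub
    if reduction == 1:     # mean
        return pointwise.mean()
    elif reduction == 2:   # sum
        return pointwise.sum()
    return pointwise       # none

x = F.log_softmax(torch.randn(3, 4), dim=-1)
y = F.softmax(torch.randn(3, 4), dim=-1)
assert torch.allclose(kl_div_decomposed(x, y, reduction=2),
                      F.kl_div(x, y, reduction="sum"))
```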

lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp

Lines changed: 1 addition & 0 deletions
```diff
@@ -587,6 +587,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp<AtenFlipudOp>();
   target.addIllegalOp<AtenLogaddexpOp>();
   target.addIllegalOp<AtenLogaddexp2Op>();
+  target.addIllegalOp<AtenKlDivOp>();
 
   for (auto &opName : backendLegalOpsSet) {
     target.addLegalOp(
```

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 15 additions & 0 deletions
```diff
@@ -39,6 +39,8 @@
     "AtenSymConstrainRange_basic",
     "AtenSymConstrainRangeForSize_basic",
     "Aten_AssertScalar_basic",
+    # RuntimeError: attribute lookup is not defined on builtin:
+    "KlDivLossModule_batchmean_reduction_basic",
 }
 
 if torch_version_for_comparison() < version.parse("2.5.0.dev"):
@@ -386,6 +388,12 @@
     "MaxPool3dStaticModule_basic",
     # Looks like incorrect fx graph conversion
     "ElementwiseAddScalar_TensorLiteralInt32_Module_basic",
+    # error: failed to legalize operation 'torch.aten.xlogy.Tensor'
+    "KlDivLossModule_default_basic",
+    "KlDivLossModule_reduction_is_none_basic",
+    "KlDivLossModule_mean_reduction_basic",
+    "KlDivLossModule_sum_reduction_basic",
+    "KlDivLossModule_batchmean_reduction_basic",
 }
 
 FX_IMPORTER_XFAIL_SET = {
@@ -3087,6 +3095,7 @@
     "PoissonNLLLossMeanReductionModule_basic",
     "PoissonNLLLossSumReductionModule_basic",
     "PoissonNLLLossNonDefaultEpsModule_basic",
+    "KlDivLossModule_batchmean_reduction_basic",
     "NormScalarComplexModule_basic",
     "NormScalarModule_basic",
     "NormScalarOptDimKeepDimComplexModule_basic",
@@ -3982,6 +3991,12 @@
     "NllLossStaticModule_mean_basic",
     "NllLossStaticModule_sum_basic",
     "NllLossStaticModule_weight_basic",
+    "KlDivLossModule_default_basic",
+    "KlDivLossModule_reduction_is_none_basic",
+    "KlDivLossModule_reduction_is_none_log_target_is_true_basic",
+    "KlDivLossModule_mean_reduction_basic",
+    "KlDivLossModule_sum_reduction_basic",
+    "KlDivLossModule_batchmean_reduction_basic",
     "Exp2StaticModule_basic",
     "ElementwiseRreluWithNoiseEvalModule_basic",
     "ElementwiseRreluWithNoiseEvalStaticModule_basic",
```

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py

Lines changed: 16 additions & 0 deletions
```diff
@@ -2174,6 +2174,14 @@ def aten〇tril_indices〡shape(row: int, col: int, offset: int = 0, dtype: Opti
 def aten〇deg2rad〡shape(self: List[int]) -> List[int]:
     return upstream_shape_functions.unary(self)
 
+def aten〇kl_div〡shape(self: List[int], target: List[int], reduction: int = 1, log_target: bool = False) -> List[int]:
+    if reduction == 0:
+        return upstream_shape_functions.unary(self)
+    elif reduction in [1, 2]:
+        return []
+    else:
+        assert False, "Invalid reduction value."
+
 @check_shape_function([
     Invocation(TensorOfShape(2, 3), LongTensorOfShape(2), None, 1, -100), # Basic case.
     Invocation(TensorOfShape(3), LongTensorOfShape(), None, 1, -100), # No batch dim.
@@ -4552,6 +4560,14 @@ def aten〇_int_mm〡dtype(self_rank_dtype: Tuple[int, int], mat2_rank_dtype: Tu
     assert mat2_dtype == torch.int8
     return torch.int32
 
+def aten〇kl_div〡dtype(self_rank_dtype: Tuple[int, int], target_rank_dtype: Tuple[int, int], reduction: int = 1, log_target: bool = False) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    target_rank, target_dtype = target_rank_dtype
+    ranks: List[Optional[int]] = [self_rank, target_rank]
+    dtypes = [self_dtype, target_dtype]
+    promoted_dtype = promote_dtypes(ranks, dtypes)
+    return promoted_dtype
+
 @check_dtype_function(_check_two_tensor_op(
     output_error_types={torch.bool, torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64}))
 def aten〇mse_loss〡dtype(self_rank_dtype: Tuple[int, int], target_rank_dtype: Tuple[int, int], reduction: int = 1) -> int:
```
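The shape and dtype rules above are straightforward to sanity-check in plain Python (renamed here, since the registry uses the special `〇`/`〡` identifiers); the dtype rule is standard two-operand promotion, as `promote_dtypes` performs:

```python
import torch

def kl_div_shape(self_shape, target_shape, reduction=1, log_target=False):
    if reduction == 0:         # none: elementwise result keeps self's shape
        return list(self_shape)
    if reduction in (1, 2):    # mean / sum: rank-0 result
        return []
    raise AssertionError("Invalid reduction value.")

assert kl_div_shape([2, 3], [2, 3], reduction=0) == [2, 3]
assert kl_div_shape([2, 3], [2, 3], reduction=1) == []

# The result dtype promotes self's and target's dtypes.
assert torch.promote_types(torch.float32, torch.float64) == torch.float64
```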

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -764,6 +764,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit(
         "aten::poisson_nll_loss : (Tensor, Tensor, bool, bool, float, int) -> (Tensor)"
     )
+    emit("aten::kl_div : (Tensor, Tensor, int, bool) -> (Tensor)")
     emit("aten::bincount : (Tensor, Tensor?, int) -> (Tensor)")
     emit("aten::linalg_vector_norm : (Tensor, Scalar, int[]?, bool, int?) -> (Tensor)")
     emit("aten::linalg_norm : (Tensor, Scalar?, int[]?, bool, int?) -> (Tensor)")
```

projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -62,3 +62,4 @@ def register_all_tests():
     from . import gridsampler
     from . import meshgrid
     from . import timeout
+    from . import kl_div_loss
```
