Commit 09c339e

[TORCH] Add support for PoissonNLLLoss Op (llvm#4232)
Fixes llvm#4209. This change introduces end-to-end support in Torch-MLIR for the `torch.nn.PoissonNLLLoss` op.

- Decomposes the PoissonNLLLoss op into Aten ops.
- TODO: support `full=True` via the Stirling approximation [[source]](https://docs.pytorch.org/docs/stable/generated/torch.nn.PoissonNLLLoss.html)
1 parent acf7fdd commit 09c339e
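
For orientation, the decomposition added below implements the standard Poisson negative log-likelihood. A minimal Python sketch of the same formula (the helper name is illustrative, not part of the patch):

```python
import torch

def poisson_nll_reference(input, target, log_input=True, eps=1e-8, reduction=1):
    # Mirrors the decomposition in DecomposeComplexOps.cpp below;
    # full=True (the Stirling term) is left as a TODO there.
    if log_input:
        loss = torch.exp(input) - target * input        # input is a log-rate
    else:
        loss = input - target * torch.log(input + eps)  # eps guards log(0)
    if reduction == 1:
        return loss.mean()  # reduction == 1: mean
    if reduction == 2:
        return loss.sum()   # reduction == 2: sum
    return loss             # reduction == 0: none
```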

8 files changed: +278 -0 lines changed

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 28 additions & 0 deletions
@@ -9478,6 +9478,34 @@ def Torch_AtenNllLossBackwardOp : Torch_Op<"aten.nll_loss_backward", [
   }];
 }
 
+def Torch_AtenPoissonNllLossOp : Torch_Op<"aten.poisson_nll_loss", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::poisson_nll_loss : (Tensor, Tensor, bool, bool, float, int) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$input,
+    AnyTorchTensorType:$target,
+    Torch_BoolType:$log_input,
+    Torch_BoolType:$full,
+    Torch_FloatType:$eps,
+    Torch_IntType:$reduction
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenPoissonNllLossOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 6, 1);
+    }
+    void AtenPoissonNllLossOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 6, 1);
+    }
+  }];
+}
+
 def Torch_AtenBincountOp : Torch_Op<"aten.bincount", [
     AllowsTypeRefinement,
     HasValueSemantics,
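
The generated op preserves the JIT operator's positional signature; for reference, a hedged example of the corresponding Python-level call (tensor values are illustrative):

```python
import torch

inp = torch.rand(2, 3)
tgt = torch.poisson(inp)
# Positional order matches the ODS arguments above:
# (input, target, log_input, full, eps, reduction)
out = torch.ops.aten.poisson_nll_loss(inp, tgt, False, False, 1e-8, 1)
print(out)  # rank-0 tensor: reduction=1 means "mean"
```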

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 25 additions & 0 deletions
@@ -10740,6 +10740,17 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    }\n"
 "    return %2 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.poisson_nll_loss\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.bool, %arg3: !torch.bool, %arg4: !torch.float, %arg5: !torch.int) -> !torch.list<int> {\n"
+"    %int0 = torch.constant.int 0\n"
+"    %0 = torch.aten.eq.int %arg5, %int0 : !torch.int, !torch.int -> !torch.bool\n"
+"    %1 = torch.prim.If %0 -> (!torch.list<int>) {\n"
+"      torch.prim.If.yield %arg0 : !torch.list<int>\n"
+"    } else {\n"
+"      %2 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
+"      torch.prim.If.yield %2 : !torch.list<int>\n"
+"    }\n"
+"    return %1 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.native_layer_norm\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.optional<list<int>>, %arg4: !torch.float) -> !torch.tuple<list<int>, list<int>, list<int>> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.native_layer_norm(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.tuple<list<int>, list<int>, list<int>>\n"
 "    return %0 : !torch.tuple<list<int>, list<int>, list<int>>\n"
@@ -15267,6 +15278,20 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %4 = torch.prim.TupleConstruct %0#1, %0#1 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
 "    return %4 : !torch.tuple<int, int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.poisson_nll_loss\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.bool, %arg3: !torch.bool, %arg4: !torch.float, %arg5: !torch.int) -> !torch.int {\n"
+"    %int6 = torch.constant.int 6\n"
+"    %int15 = torch.constant.int 15\n"
+"    %int5 = torch.constant.int 5\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1 = torch.prim.ListConstruct %int5, %int15 : (!torch.int, !torch.int) -> !torch.list<int>\n"
+"    %2 = torch.aten.__contains__.int_list %1, %0#1 : !torch.list<int>, !torch.int -> !torch.bool\n"
+"    %3 = torch.prim.If %2 -> (!torch.int) {\n"
+"      torch.prim.If.yield %int6 : !torch.int\n"
+"    } else {\n"
+"      torch.prim.If.yield %0#1 : !torch.int\n"
+"    }\n"
+"    return %3 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.native_layer_norm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.optional<tuple<int, int>>, %arg3: !torch.optional<tuple<int, int>>, %arg4: !torch.float) -> !torch.tuple<int, int, int> {\n"
 "    %int7 = torch.constant.int 7\n"
 "    %int10 = torch.constant.int 10\n"

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 77 additions & 0 deletions
@@ -10553,6 +10553,82 @@ class DecomposeAtenNllLossForwardOp
 };
 } // namespace
 
+namespace {
+class DecomposeAtenPoissonNllLossOp
+    : public OpRewritePattern<AtenPoissonNllLossOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenPoissonNllLossOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value input = op.getInput();
+    Value target = op.getTarget();
+    Value logInput = op.getLogInput();
+    Value full = op.getFull();
+    Value reduction = op.getReduction();
+    Value eps = op.getEps();
+
+    bool logInVal, fullVal;
+    if (!matchPattern(logInput, m_TorchConstantBool(&logInVal)))
+      return rewriter.notifyMatchFailure(
+          op, "expected logInput argument to be constant bool");
+    if (!matchPattern(full, m_TorchConstantBool(&fullVal)))
+      return rewriter.notifyMatchFailure(
+          op, "expected full argument to be constant bool");
+
+    int64_t reductionInt;
+    if (!matchPattern(reduction, m_TorchConstantInt(&reductionInt)))
+      return rewriter.notifyMatchFailure(op, "expected constant reduction");
+
+    double epsFloat;
+    if (!matchPattern(eps, m_TorchConstantFloat(&epsFloat))) {
+      return rewriter.notifyMatchFailure(op, "expected constant eps");
+    }
+    // TODO: add support for full=true (Stirling approximation)
+    if (fullVal)
+      return rewriter.notifyMatchFailure(
+          op, "Unimplemented: full loss computation is not supported");
+
+    Value one =
+        rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
+    Value epsConst = rewriter.create<ConstantFloatOp>(
+        loc, rewriter.getF64FloatAttr(epsFloat));
+
+    Value safeInput = rewriter.create<AtenAddScalarOp>(loc, input.getType(),
+                                                       input, epsConst, one);
+
+    Value loss;
+    if (logInVal) {
+      Value expIn = rewriter.create<AtenExpOp>(loc, input.getType(), input);
+      Value targetMulInput =
+          rewriter.create<AtenMulTensorOp>(loc, input.getType(), target, input);
+      loss = rewriter.create<AtenSubTensorOp>(loc, input.getType(), expIn,
+                                              targetMulInput, one);
+    } else {
+      Value logSafeInput =
+          rewriter.create<AtenLogOp>(loc, input.getType(), safeInput);
+      Value targetMulLog = rewriter.create<AtenMulTensorOp>(
+          loc, input.getType(), target, logSafeInput);
+      loss = rewriter.create<AtenSubTensorOp>(loc, input.getType(), input,
+                                              targetMulLog, one);
+    }
+
+    Value result = loss;
+    if (reductionInt == 1) {
+      // Case 1: Mean Reduction
+      result = rewriter.create<AtenMeanOp>(
+          loc, op.getType(), loss, rewriter.create<ConstantNoneOp>(loc));
+    } else if (reductionInt == 2) {
+      // Case 2: Sum Reduction
+      result = rewriter.create<AtenSumOp>(loc, op.getType(), loss,
+                                          rewriter.create<ConstantNoneOp>(loc));
+    }
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposeAtenBinaryCrossEntropyWithLogitsOp
     : public OpRewritePattern<AtenBinaryCrossEntropyWithLogitsOp> {
@@ -12467,6 +12543,7 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal<DecomposeAtenOneHotOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenCrossEntropyLossOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenNllLossForwardOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenPoissonNllLossOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenBinaryCrossEntropyWithLogitsOp>(
         patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanDimOp>(patterns);
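
A quick way to sanity-check both branches of the decomposition against PyTorch's own kernel (a standalone sketch assuming a local torch install; shapes and eps values are illustrative):

```python
import torch
import torch.nn.functional as F

inp = torch.rand(4, 3)
tgt = torch.poisson(inp)

# log_input=False branch: input - target * log(input + eps), mean reduction.
manual = (inp - tgt * torch.log(inp + 1e-8)).mean()
expected = F.poisson_nll_loss(
    inp, tgt, log_input=False, full=False, eps=1e-8, reduction="mean"
)
assert torch.allclose(manual, expected)

# log_input=True branch: exp(input) - target * input, sum reduction.
manual = (torch.exp(inp) - tgt * inp).sum()
expected = F.poisson_nll_loss(
    inp, tgt, log_input=True, full=False, eps=1e-8, reduction="sum"
)
assert torch.allclose(manual, expected)
```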

lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -538,6 +538,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
538538
target.addIllegalOp<AtenLerpTensorOp>();
539539
target.addIllegalOp<AtenMseLossOp>();
540540
target.addIllegalOp<AtenL1LossOp>();
541+
target.addIllegalOp<AtenPoissonNllLossOp>();
541542
target.addIllegalOp<AtenRandintLowOp>();
542543
target.addIllegalOp<AtenRandintOp>();
543544
target.addIllegalOp<AtenVarMeanCorrectionOp>();

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3083,6 +3083,10 @@
30833083
"NllLossStaticModule_mean_basic",
30843084
"NllLossModule_sum_basic",
30853085
"NllLossStaticModule_sum_basic",
3086+
"PoissonNLLLossNoReductionModule_basic",
3087+
"PoissonNLLLossMeanReductionModule_basic",
3088+
"PoissonNLLLossSumReductionModule_basic",
3089+
"PoissonNLLLossNonDefaultEpsModule_basic",
30863090
"NormScalarComplexModule_basic",
30873091
"NormScalarModule_basic",
30883092
"NormScalarOptDimKeepDimComplexModule_basic",

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2208,6 +2208,16 @@ def aten〇binary_cross_entropy_with_logits〡shape(self: List[int], target: Lis
22082208
result_shape = scalar_shape
22092209
return result_shape
22102210

2211+
@check_shape_function([
2212+
Invocation(TensorOfShape(2, 3), TensorOfShape(2, 3), True, False, 1e-8, 0), # No reduction
2213+
Invocation(TensorOfShape(2, 3), TensorOfShape(2, 3), True, False, 1e-8, 1), # Mean reduction
2214+
Invocation(TensorOfShape(2, 3), TensorOfShape(2, 3), True, False, 1e-8, 2), # Sum reduction
2215+
])
2216+
def aten〇poisson_nll_loss〡shape(input: List[int], target: List[int], log_input: bool, full: bool, eps: float, reduction: int) -> List[int]:
2217+
if reduction == 0:
2218+
return input
2219+
return []
2220+
22112221
@check_shape_function([
22122222
Invocation(TensorOfShape(2, 5, 2, 2, 3), [2, 2, 3], None, None, 1e-6), # Basic case.
22132223
])
@@ -5113,6 +5123,20 @@ def aten〇nll_loss_forward〡dtype(self_rank_dtype: Tuple[int, int], target_ran
51135123
assert target_dtype == torch.int64 or target_dtype == torch.int32
51145124
return self_dtype, self_dtype
51155125

5126+
@check_dtype_function([
5127+
Invocation(TensorOfShape(2, 3, dtype=torch.float32), TensorOfShape(2, 3, dtype=torch.float32), # No reduction
5128+
True, False, 1e-8, 0),
5129+
Invocation(TensorOfShape(4, 5, dtype=torch.float32), TensorOfShape(4, 5, dtype=torch.float32), # Mean reduction
5130+
True, False, 1e-8, 1),
5131+
Invocation(TensorOfShape(3, 3, dtype=torch.float64), TensorOfShape(3, 3, dtype=torch.float64), # Sum reduction
5132+
True, False, 1e-8, 2),
5133+
])
5134+
def aten〇poisson_nll_loss〡dtype(input_rank_dtype: Tuple[int, int], target_rank_dtype: Tuple[int, int], log_input: bool, full: bool, eps: float, reduction: int) -> int:
5135+
_, input_dtype = input_rank_dtype
5136+
if input_dtype in (torch.float16, torch.bfloat16):
5137+
return torch.float32
5138+
return input_dtype
5139+
51165140
@check_dtype_function(
51175141
[Invocation(TensorOfShape(2, 3, dtype=torch.float32), [3], TensorOfShape(3, dtype=torch.float32),
51185142
TensorOfShape(3, dtype=torch.float32), eps=0.0),
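
Illustrative invocations of the shape function above, assuming they run in the same module where it is defined:

```python
# reduction == 0 keeps the elementwise shape; mean/sum reduce to rank 0 ([]).
assert aten〇poisson_nll_loss〡shape([2, 3], [2, 3], True, False, 1e-8, 0) == [2, 3]
assert aten〇poisson_nll_loss〡shape([2, 3], [2, 3], True, False, 1e-8, 1) == []
assert aten〇poisson_nll_loss〡shape([2, 3], [2, 3], True, False, 1e-8, 2) == []
```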

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py

Lines changed: 3 additions & 0 deletions
@@ -761,6 +761,9 @@ def emit_with_mutating_variants(key, **kwargs):
     emit(
         "aten::nll_loss_backward : (Tensor, Tensor, Tensor, Tensor?, int, int, Tensor) -> (Tensor)"
     )
+    emit(
+        "aten::poisson_nll_loss : (Tensor, Tensor, bool, bool, float, int) -> (Tensor)"
+    )
     emit("aten::bincount : (Tensor, Tensor?, int) -> (Tensor)")
     emit("aten::linalg_vector_norm : (Tensor, Scalar, int[]?, bool, int?) -> (Tensor)")
     emit("aten::linalg_norm : (Tensor, Scalar?, int[]?, bool, int?) -> (Tensor)")

projects/pt1/python/torch_mlir_e2e_test/test_suite/nll_loss.py

Lines changed: 116 additions & 0 deletions
@@ -670,3 +670,119 @@ def NllLossModuleBackward1DSumWeight_basic(module, tu: TestUtils):
     module.forward(
         tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]), tu.rand(3), torch.tensor(3.0)
     )
+
+
+class PoissonNLLLossNoReductionModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [None, ([-1, -1], torch.float32, True), ([-1, -1], torch.float32, True)]
+    )
+    def forward(self, input, target):
+        return torch.ops.aten.poisson_nll_loss(
+            input=input,
+            target=target,
+            log_input=False,
+            full=False,
+            eps=1e-8,
+            reduction=0,
+        )
+
+
+@register_test_case(module_factory=lambda: PoissonNLLLossNoReductionModule())
+def PoissonNLLLossNoReductionModule_basic(
+    module: PoissonNLLLossNoReductionModule, tu: TestUtils
+):
+    input = tu.rand(4, 3).abs()
+    target = torch.poisson(input)
+    module.forward(input, target)
+
+
+class PoissonNLLLossMeanReductionModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1], torch.float32, True),
+            ([-1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, input, target):
+        return torch.ops.aten.poisson_nll_loss(
+            input=input,
+            target=target,
+            log_input=True,
+            full=False,
+            eps=1e-8,
+            reduction=1,
+        )
+
+
+@register_test_case(module_factory=lambda: PoissonNLLLossMeanReductionModule())
+def PoissonNLLLossMeanReductionModule_basic(
+    module: PoissonNLLLossMeanReductionModule, tu: TestUtils
+):
+    input = tu.rand(5, 7).abs()
+    target = torch.poisson(input)
+    module.forward(input, target)
+
+
+class PoissonNLLLossSumReductionModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [None, ([-1, -1], torch.float32, True), ([-1, -1], torch.float32, True)]
+    )
+    def forward(self, input, target):
+        return torch.ops.aten.poisson_nll_loss(
+            input=input,
+            target=target,
+            log_input=True,
+            full=False,
+            eps=1e-8,
+            reduction=2,
+        )
+
+
+@register_test_case(module_factory=lambda: PoissonNLLLossSumReductionModule())
+def PoissonNLLLossSumReductionModule_basic(
+    module: PoissonNLLLossSumReductionModule, tu: TestUtils
+):
+    input = tu.rand(3, 3)
+    target = torch.poisson(input.abs())
+    module.forward(input, target)
+
+
+class PoissonNLLLossNonDefaultEpsModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [None, ([-1, -1], torch.float32, True), ([-1, -1], torch.float32, True)]
+    )
+    def forward(self, input, target):
+        return torch.ops.aten.poisson_nll_loss(
+            input=input,
+            target=target,
+            log_input=False,
+            full=False,
+            eps=0.5,
+            reduction=1,
+        )
+
+
+@register_test_case(module_factory=lambda: PoissonNLLLossNonDefaultEpsModule())
+def PoissonNLLLossNonDefaultEpsModule_basic(
+    module: PoissonNLLLossNonDefaultEpsModule, tu: TestUtils
+):
+    input = tu.rand(5, 4)
+    target = torch.poisson(input.abs())
+    module.forward(input, target)
