
Commit 8a79a1c

Add MultiLabelMarginLoss (#73538)
1 parent 2168d8b commit 8a79a1c

6 files changed, +616 -0 lines changed

python/paddle/nn/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -108,6 +108,7 @@
     L1Loss,
     MarginRankingLoss,
     MSELoss,
+    MultiLabelMarginLoss,
     MultiLabelSoftMarginLoss,
     MultiMarginLoss,
     NLLLoss,
@@ -305,6 +306,7 @@
     'CosineEmbeddingLoss',
     'RReLU',
     'MultiMarginLoss',
+    'MultiLabelMarginLoss',
     'TripletMarginWithDistanceLoss',
     'TripletMarginLoss',
     'SoftMarginLoss',

python/paddle/nn/functional/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -116,6 +116,7 @@
     margin_cross_entropy,
     margin_ranking_loss,
     mse_loss,
+    multi_label_margin_loss,
     multi_label_soft_margin_loss,
     multi_margin_loss,
     nll_loss,
@@ -292,6 +293,7 @@
     'triplet_margin_loss',
     'adaptive_log_softmax_with_loss',
     'multi_margin_loss',
+    'multi_label_margin_loss',
     'soft_margin_loss',
     'gaussian_nll_loss',
     'scaled_dot_product_attention',
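
With both registrations in place, the new loss is reachable from either namespace. A quick smoke check (a sketch, assuming a Paddle build that already contains this commit):

import paddle.nn as nn
import paddle.nn.functional as F

# Both entry points registered by the two hunks above should resolve.
layer = nn.MultiLabelMarginLoss()
assert callable(F.multi_label_margin_loss)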

python/paddle/nn/functional/loss.py

Lines changed: 119 additions & 0 deletions
@@ -4190,6 +4190,125 @@ def multi_margin_loss(
     return loss


+def multi_label_margin_loss(
+    input: Tensor,
+    label: Tensor,
+    reduction: _ReduceMode = 'mean',
+    name: str | None = None,
+) -> Tensor:
+    r"""Measures a multi-class multi-classification hinge loss (margin-based loss) between input :math:`input` and label :math:`label`:
+
+    For the i-th mini-batch sample, the loss in terms of the 2D input :math:`input_i` and 2D label :math:`label_i` is:
+
+    .. math::
+        \text{loss}(input_i, label_i) = \frac{\sum_{j \in \text{valid\_labels}} \sum_{k \notin \text{valid\_labels}} \max(0, 1 - (input_i[j] - input_i[k]))}{C}
+
+    where :math:`C` is the number of classes, :math:`\text{valid\_labels}` contains all non-negative label indices
+    of sample :math:`i` (stopping at the first -1 encountered), and :math:`k` ranges over all class indices
+    not in :math:`\text{valid\_labels}`.
+
+    The criterion only considers the leading non-negative label values, which allows different samples to have variable numbers of target classes.
+
+    Parameters:
+        input (Tensor): Input tensor, the data type is float32 or float64. Shape is (N, C), where C is the number of classes.
+        label (Tensor): Label tensor, the data type is int32 or int64. Shape is (N, C), the same shape as input.
+            Label values should be class indices (non-negative values) or -1.
+            The first -1 in a row stops processing for that sample; it and all later values are ignored.
+        reduction (str, optional): Indicate how to reduce the loss over the batch,
+            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
+            If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
+            If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
+            If :attr:`reduction` is ``'sum'``, the summed loss is returned.
+            Default: ``'mean'``
+        name (str|None, optional): Name for the operation (optional, default is None).
+            For more information, please refer to :ref:`api_guide_Name`.
+
+    Returns:
+        Tensor, the tensor variable storing the multi_label_margin_loss of input and label.
+
+    Examples:
+        .. code-block:: python
+
+            >>> import paddle
+            >>> import paddle.nn.functional as F
+
+            >>> input = paddle.to_tensor([[0.1, 0.2, 0.4, 0.8], [0.2, 0.5, 0.3, 0.1]], dtype='float32')
+            >>> label = paddle.to_tensor([[3, 0, -1, -1], [0, 2, -1, -1]], dtype='int64')
+
+            >>> loss = F.multi_label_margin_loss(input, label, reduction='mean')
+            >>> print(loss)
+            Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
+                   0.94999999)
+    """
+    if reduction not in ['sum', 'mean', 'none']:
+        raise ValueError(
+            "'reduction' in 'multi_label_margin_loss' should be 'sum', 'mean' or 'none', "
+            f"but received {reduction}."
+        )
+
+    if not in_dynamic_mode():
+        check_variable_and_dtype(
+            input, 'input', ['float32', 'float64'], 'multi_label_margin_loss'
+        )
+        check_variable_and_dtype(
+            label, 'label', ['int32', 'int64'], 'multi_label_margin_loss'
+        )
+
+    if input.dim() != 2:
+        raise ValueError(f'Expected 2D input tensor, but got {input.dim()}D')
+
+    if label.dim() != 2:
+        raise ValueError(f'Expected 2D label tensor, but got {label.dim()}D')
+
+    N, C = input.shape
+
+    if paddle.in_dynamic_mode() and label.numel() > 0:
+        min_val = paddle.min(label).item()
+        max_val = paddle.max(label).item()
+
+        if min_val < -1:
+            raise ValueError("label values should be >= -1")
+        if max_val >= C:
+            raise ValueError(f"label values should be < {C}")
+
+    # mask of leading non-negative labels; cumprod zeroes out from the first -1
+    valid_mask = (label != -1).cast('int32')
+    valid_mask = valid_mask * valid_mask.cumprod(dim=1)
+
+    row_ids, col_ids = paddle.where(valid_mask)
+    targets_flat = label[row_ids, col_ids]
+
+    invalid_mask = paddle.ones([N, C], dtype='bool')
+    invalid_mask[row_ids, targets_flat] = False
+
+    # calculate margin by broadcasting
+    input_target = input[row_ids, targets_flat].unsqueeze(-1)
+    margin = 1 - input_target + input[row_ids]
+    margin = paddle.where(
+        invalid_mask[row_ids], margin, paddle.zeros_like(margin)
+    )
+
+    relu_margin = paddle.maximum(margin, paddle.zeros_like(margin))
+
+    losses = paddle.scatter_nd_add(
+        paddle.zeros([N], dtype=input.dtype),
+        row_ids.unsqueeze(-1),
+        relu_margin.sum(
+            axis=1,
+        ),
+    )
+
+    # normalize by the number of classes C
+    losses /= C
+
+    if reduction == 'mean':
+        return paddle.mean(losses, name=name)
+    elif reduction == 'sum':
+        return paddle.sum(losses, name=name)
+    elif reduction == 'none':
+        return losses
+
+
 def soft_margin_loss(
     input: Tensor,
     label: Tensor,
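
As a cross-check of the formula and the docstring example, the following is a minimal NumPy re-implementation of the per-sample loss. It is a verification sketch only, not part of the commit, and the helper name multi_label_margin_loss_ref is hypothetical:

import numpy as np

def multi_label_margin_loss_ref(inp: np.ndarray, label: np.ndarray) -> np.ndarray:
    """Reference per-sample loss, mirroring the docstring formula."""
    N, C = inp.shape
    losses = np.zeros(N, dtype=np.float64)
    for i in range(N):
        # Targets are the leading non-negative labels, up to the first -1.
        valid = []
        for t in label[i]:
            if t == -1:
                break
            valid.append(int(t))
        others = [k for k in range(C) if k not in valid]
        # Hinge term for every (target, non-target) pair, normalized by C.
        for j in valid:
            for k in others:
                losses[i] += max(0.0, 1.0 - (inp[i, j] - inp[i, k]))
        losses[i] /= C
    return losses

inp = np.array([[0.1, 0.2, 0.4, 0.8], [0.2, 0.5, 0.3, 0.1]])
label = np.array([[3, 0, -1, -1], [0, 2, -1, -1]])
print(multi_label_margin_loss_ref(inp, label))         # [0.85 1.05]
print(multi_label_margin_loss_ref(inp, label).mean())  # 0.95, matching 0.94999999 above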

python/paddle/nn/layer/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -72,6 +72,7 @@
     L1Loss,
     MarginRankingLoss,
     MSELoss,
+    MultiLabelMarginLoss,
     MultiLabelSoftMarginLoss,
     MultiMarginLoss,
     NLLLoss,

python/paddle/nn/layer/loss.py

Lines changed: 86 additions & 0 deletions
@@ -2197,6 +2197,92 @@ def forward(self, input: Tensor, label: Tensor) -> Tensor:
         )


+class MultiLabelMarginLoss(Layer):
+    r"""Creates a criterion that optimizes a multi-class multi-classification hinge loss (margin-based loss)
+    between input :math:`input` and label :math:`label`:
+
+    For the i-th mini-batch sample, the loss in terms of the 2D input :math:`input_i` and 2D label :math:`label_i` is:
+
+    .. math::
+        \text{loss}(input_i, label_i) = \frac{\sum_{j \in \text{valid\_labels}} \sum_{k \notin \text{valid\_labels}} \max(0, 1 - (input_i[j] - input_i[k]))}{C}
+
+    where :math:`C` is the number of classes, :math:`\text{valid\_labels}` contains all non-negative label indices
+    of sample :math:`i` (stopping at the first -1 encountered), and :math:`k` ranges over all class indices
+    not in :math:`\text{valid\_labels}`.
+
+    The criterion only considers the leading non-negative label values, which allows different samples to have variable numbers of target classes.
+
+    Parameters:
+
+        reduction (str, optional): Indicate how to reduce the loss over the batch,
+            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
+            If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
+            If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
+            If :attr:`reduction` is ``'sum'``, the summed loss is returned.
+            Default: ``'mean'``
+
+        name (str|None, optional): Name for the operation (optional, default is None).
+            For more information, please refer to :ref:`api_guide_Name`.
+
+    Call parameters:
+        input (Tensor): Input tensor, the data type is float32 or float64.
+
+        label (Tensor): Label tensor, the data type is int32 or int64.
+            Label values should be class indices (non-negative values) or -1.
+            The first -1 in a row stops processing for that sample; it and all later values are ignored.
+
+    Shape:
+        input: 2-D Tensor, the shape is :math:`[N, C]`, where :math:`N` is the batch size and :math:`C` is the number of classes.
+
+        label: 2-D Tensor, the shape is :math:`[N, C]`, the same shape as input.
+
+        output: scalar. If :attr:`reduction` is ``'none'``, the shape is :math:`[N]`.
+
+    Returns:
+        A callable object of MultiLabelMarginLoss.
+
+    Examples:
+        .. code-block:: python
+
+            >>> import paddle
+            >>> import paddle.nn as nn
+
+            >>> input = paddle.to_tensor([[0.1, 0.2, 0.4, 0.8], [0.2, 0.5, 0.3, 0.1]], dtype='float32')
+            >>> label = paddle.to_tensor([[3, 0, -1, -1], [0, 2, -1, -1]], dtype='int64')
+
+            >>> multi_label_margin_loss = nn.MultiLabelMarginLoss(reduction='mean')
+            >>> loss = multi_label_margin_loss(input, label)
+            >>> print(loss)
+            Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
+                   0.94999999)
+    """
+
+    reduction: _ReduceMode
+    name: str | None
+
+    def __init__(
+        self,
+        reduction: _ReduceMode = 'mean',
+        name: str | None = None,
+    ) -> None:
+        super().__init__()
+        if reduction not in ['sum', 'mean', 'none']:
+            raise ValueError(
+                "'reduction' in 'MultiLabelMarginLoss' should be 'sum', 'mean' or 'none', "
+                f"but received {reduction}."
+            )
+        self.reduction = reduction
+        self.name = name
+
+    def forward(self, input: Tensor, label: Tensor) -> Tensor:
+        return F.multi_label_margin_loss(
+            input,
+            label,
+            reduction=self.reduction,
+            name=self.name,
+        )
+
+
 class SoftMarginLoss(Layer):
     r"""
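
Since forward simply delegates to F.multi_label_margin_loss, the Layer and functional forms agree for every reduction mode. A small usage sketch (again assuming a Paddle build with this commit; the per-sample values follow from the worked example after loss.py above):

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

input = paddle.to_tensor([[0.1, 0.2, 0.4, 0.8], [0.2, 0.5, 0.3, 0.1]], dtype='float32')
label = paddle.to_tensor([[3, 0, -1, -1], [0, 2, -1, -1]], dtype='int64')

# reduction='none' keeps the per-sample losses: [0.85, 1.05] for this input.
per_sample = F.multi_label_margin_loss(input, label, reduction='none')

# The Layer wrapper stores the reduction and forwards to the functional API.
loss_layer = nn.MultiLabelMarginLoss(reduction='mean')
assert paddle.allclose(loss_layer(input, label), per_sample.mean())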
