Skip to content

add some loss api #109

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Jul 5, 2023
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 78 additions & 5 deletions paconvert/api_mapping.json
Original file line number Diff line number Diff line change
Expand Up @@ -5440,6 +5440,16 @@
"divisor_override"
]
},
"torch.nn.BCELoss": {
"Matcher": "SizeAverageMatcher",
"paddle_api": "paddle.nn.BCELoss",
"args_list": [
"weight",
"size_average",
"reduce",
"reduction"
]
},
"torch.nn.BCEWithLogitsLoss": {
"Matcher": "SizeAverageMatcher",
"paddle_api": "paddle.nn.BCEWithLogitsLoss",
Expand Down Expand Up @@ -5946,6 +5956,15 @@
"dtype": ""
}
},
"torch.nn.L1Loss": {
"Matcher": "SizeAverageMatcher",
"paddle_api": "paddle.nn.L1Loss",
"args_list": [
"size_average",
"reduce",
"reduction"
]
},
"torch.nn.LSTM": {
"Matcher": "RNNMatcher",
"paddle_api": "paddle.nn.LSTM",
Expand Down Expand Up @@ -6051,6 +6070,15 @@
"dim": "axis"
}
},
"torch.nn.MSELoss": {
"Matcher": "SizeAverageMatcher",
"paddle_api": "paddle.nn.MSELoss",
"args_list": [
"size_average",
"reduce",
"reduction"
]
},
"torch.nn.MaxPool1d": {
"Matcher": "MaxPoolMatcher",
"paddle_api": "paddle.nn.MaxPool1D",
Expand Down Expand Up @@ -6682,6 +6710,22 @@
"unflattened_size": "shape"
}
},
"torch.nn.Unfold": {
"Matcher": "Tuple2ListMatcher",
"paddle_api": "paddle.nn.Unfold",
"args_list": [
"kernel_size",
"dilation",
"padding",
"stride"
],
"kwargs_change": {
"kernel_size": "kernel_sizes",
"dilation": "dilations",
"padding": "paddings",
"stride": "strides"
}
},
"torch.nn.UpsamplingBilinear2d": {
"Matcher": "UpsampleMatcher",
"paddle_api": "paddle.nn.UpsamplingBilinear2D",
Expand Down Expand Up @@ -6874,6 +6918,21 @@
"input2": "x2"
}
},
"torch.nn.functional.binary_cross_entropy": {
"Matcher": "SizeAverageMatcher",
"paddle_api": "paddle.nn.functional.binary_cross_entropy",
"args_list": [
"input",
"target",
"weight",
"size_average",
"reduce",
"reduction"
],
"kwargs_change": {
"target": "label"
}
},
"torch.nn.functional.binary_cross_entropy_with_logits": {
"Matcher": "SizeAverageMatcher",
"paddle_api": "paddle.nn.functional.binary_cross_entropy_with_logits",
Expand Down Expand Up @@ -7764,6 +7823,24 @@
"reduction"
]
},
"torch.nn.functional.unfold": {
"Matcher": "Tuple2ListMatcher",
"paddle_api": "paddle.nn.functional.unfold",
"args_list": [
"input",
"kernel_size",
"dilation",
"padding",
"stride"
],
"kwargs_change": {
"input": "x",
"kernel_size": "kernel_sizes",
"dilation": "dilations",
"padding": "paddings",
"stride": "strides"
}
},
"torch.nn.functional.upsample": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.nn.functional.upsample",
Expand Down Expand Up @@ -9245,11 +9322,7 @@
"paddle_api": "paddle.zeros_like",
"args_list": [
"input",
"dtype",
"layout",
"device",
"requires_grad",
"memory_format"
"index"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个改错了吧

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修正

],
"kwargs_change": {
"input": "x"
Expand Down
21 changes: 21 additions & 0 deletions paconvert/api_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -3193,6 +3193,27 @@ def generate_code(self, kwargs):
return GenericMatcher.generate_code(self, kwargs)


class Tuple2ListMatcher(BaseMatcher):
def generate_code(self, kwargs):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

逻辑可以写成对每个kwargs遍历,判断是否kwargs,每个分支里再判断是否list,一共4个分支。用new_kwargs来接收kwargs,不然参数顺序会改变,导致代码风格不太好

for k in list(kwargs.keys()):
    if kwargs_change:
          if tuple:
          else:
    else:
          if tuple:
          else:

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

new_kwargs = {}
kwargs_change = self.api_mapping["kwargs_change"]
for k in list(kwargs.keys()):
if k in kwargs_change:
if "(" in kwargs[k] and isinstance(ast.literal_eval(kwargs[k]), tuple):
new_kwargs[kwargs_change[k]] = list(ast.literal_eval(kwargs[k]))
Copy link
Collaborator

@zhwesky2010 zhwesky2010 Jul 4, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

判断前面这个就可以,直接 'list({})'.format(kwargs[k]) 吧,尽量不要用eval执行这些逻辑,容易存在隐患

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已更新判断逻辑

else:
new_kwargs[kwargs_change[k]] = kwargs[k]
else:
if "(" in kwargs[k] and isinstance(ast.literal_eval(kwargs[k]), tuple):
new_kwargs[k] = list(ast.literal_eval(kwargs[k]))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

else:
new_kwargs[k] = kwargs[k]

code = "{}({})".format(self.get_paddle_api(), self.kwargs_to_str(new_kwargs))

return code


class ParameterMatcher(BaseMatcher):
def get_paddle_nodes(self, args, kwargs):
kwargs = self.parse_args_and_kwargs(args, kwargs)
Expand Down
146 changes: 146 additions & 0 deletions tests/test_nn_BCELoss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import textwrap

from apibase import APIBase

obj = APIBase("torch.nn.BCELoss")


def test_case_1():
pytorch_code = textwrap.dedent(
"""
import torch
import torch.nn as nn
input = torch.tensor([[0.2837, 0.0297, 0.0355],
[ 0.9112, 0.7526, 0.4061]])
target = torch.tensor([[1.,0.,1.],[0.,1.,0.]])
weight = torch.tensor([0.5,0.2,0.3])
loss = torch.nn.BCELoss(weight=weight,size_average=True)
result = loss(input,target)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_2():
pytorch_code = textwrap.dedent(
"""
import torch
import torch.nn as nn
input = torch.tensor([[0.2837, 0.0297, 0.0355],
[ 0.9112, 0.7526, 0.4061]])
target = torch.tensor([[1.,0.,1.],[0.,1.,0.]])
weight = torch.tensor([0.5,0.2,0.3])
loss = torch.nn.BCELoss(weight=weight,size_average=False)
result = loss(input,target)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_3():
pytorch_code = textwrap.dedent(
"""
import torch
import torch.nn as nn
input = torch.tensor([[0.2837, 0.0297, 0.0355],
[ 0.9112, 0.7526, 0.4061]])
target = torch.tensor([[1.,0.,1.],[0.,1.,0.]])
weight = torch.tensor([0.5,0.2,0.3])
loss = torch.nn.BCELoss(weight=weight,reduction='none')
result = loss(input,target)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_4():
pytorch_code = textwrap.dedent(
"""
import torch
import torch.nn as nn
input = torch.tensor([[0.2837, 0.0297, 0.0355],
[ 0.9112, 0.7526, 0.4061]])
target = torch.tensor([[1.,0.,1.],[0.,1.,0.]])
weight = torch.tensor([0.5,0.2,0.3])
loss = torch.nn.BCELoss(weight=weight,reduction='mean')
result = loss(input,target)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_5():
pytorch_code = textwrap.dedent(
"""
import torch
import torch.nn as nn
input = torch.tensor([[0.2837, 0.0297, 0.0355],
[ 0.9112, 0.7526, 0.4061]])
target = torch.tensor([[1.,0.,1.],[0.,1.,0.]])
weight = torch.tensor([0.5,0.2,0.3])
loss = torch.nn.BCELoss(weight=weight,reduction='sum')
result = loss(input,target)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_6():
pytorch_code = textwrap.dedent(
"""
import torch
import torch.nn as nn
input = torch.tensor([[0.2837, 0.0297, 0.0355],
[ 0.9112, 0.7526, 0.4061]])
target = torch.tensor([[1.,0.,1.],[0.,1.,0.]])
weight = torch.tensor([0.5,0.2,0.3])
loss = torch.nn.BCELoss(weight=weight,reduce=True)
result = loss(input,target)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_7():
pytorch_code = textwrap.dedent(
"""
import torch
import torch.nn as nn
input = torch.tensor([[0.2837, 0.0297, 0.0355],
[ 0.9112, 0.7526, 0.4061]])
target = torch.tensor([[1.,0.,1.],[0.,1.,0.]])
weight = torch.tensor([0.5,0.2,0.3])
loss = torch.nn.BCELoss(weight=weight,reduce=False)
result = loss(input,target)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_8():
pytorch_code = textwrap.dedent(
"""
import torch
import torch.nn as nn
input = torch.tensor([[0.2837, 0.0297, 0.0355],
[ 0.9112, 0.7526, 0.4061]])
target = torch.tensor([[1.,0.,1.],[0.,1.,0.]])
loss = torch.nn.BCELoss()
result = loss(input,target)
"""
)
obj.run(pytorch_code, ["result"])
Loading