-
Notifications
You must be signed in to change notification settings - Fork 61
add some loss api #109
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add some loss api #109
Changes from 10 commits
a6ff8c9
7321408
a2d50df
f9099b9
77485fd
5bf245b
f95a8f0
ae26cd1
e560dd6
7d0bc25
6035745
0bd2a86
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3193,6 +3193,27 @@ def generate_code(self, kwargs): | |
return GenericMatcher.generate_code(self, kwargs) | ||
|
||
|
||
class Tuple2ListMatcher(BaseMatcher): | ||
def generate_code(self, kwargs): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 逻辑可以写成对每个kwargs遍历,判断是否kwargs,每个分支里再判断是否list,一共4个分支。用new_kwargs来接收kwargs,不然参数顺序会改变,导致代码风格不太好
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
new_kwargs = {} | ||
kwargs_change = self.api_mapping["kwargs_change"] | ||
for k in list(kwargs.keys()): | ||
if k in kwargs_change: | ||
if "(" in kwargs[k] and isinstance(ast.literal_eval(kwargs[k]), tuple): | ||
new_kwargs[kwargs_change[k]] = list(ast.literal_eval(kwargs[k])) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 判断前面这个就可以,直接 'list({})'.format(kwargs[k]) 吧,尽量不要用eval执行这些逻辑,容易存在隐患 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已更新判断逻辑 |
||
else: | ||
new_kwargs[kwargs_change[k]] = kwargs[k] | ||
else: | ||
if "(" in kwargs[k] and isinstance(ast.literal_eval(kwargs[k]), tuple): | ||
new_kwargs[k] = list(ast.literal_eval(kwargs[k])) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 同上 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
else: | ||
new_kwargs[k] = kwargs[k] | ||
|
||
code = "{}({})".format(self.get_paddle_api(), self.kwargs_to_str(new_kwargs)) | ||
|
||
return code | ||
|
||
|
||
class ParameterMatcher(BaseMatcher): | ||
def get_paddle_nodes(self, args, kwargs): | ||
kwargs = self.parse_args_and_kwargs(args, kwargs) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,146 @@ | ||
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import textwrap | ||
|
||
from apibase import APIBase | ||
|
||
obj = APIBase("torch.nn.BCELoss") | ||
|
||
|
||
def test_case_1(): | ||
pytorch_code = textwrap.dedent( | ||
""" | ||
import torch | ||
import torch.nn as nn | ||
input = torch.tensor([[0.2837, 0.0297, 0.0355], | ||
[ 0.9112, 0.7526, 0.4061]]) | ||
target = torch.tensor([[1.,0.,1.],[0.,1.,0.]]) | ||
weight = torch.tensor([0.5,0.2,0.3]) | ||
loss = torch.nn.BCELoss(weight=weight,size_average=True) | ||
result = loss(input,target) | ||
""" | ||
) | ||
obj.run(pytorch_code, ["result"]) | ||
|
||
|
||
def test_case_2(): | ||
pytorch_code = textwrap.dedent( | ||
""" | ||
import torch | ||
import torch.nn as nn | ||
input = torch.tensor([[0.2837, 0.0297, 0.0355], | ||
[ 0.9112, 0.7526, 0.4061]]) | ||
target = torch.tensor([[1.,0.,1.],[0.,1.,0.]]) | ||
weight = torch.tensor([0.5,0.2,0.3]) | ||
loss = torch.nn.BCELoss(weight=weight,size_average=False) | ||
result = loss(input,target) | ||
""" | ||
) | ||
obj.run(pytorch_code, ["result"]) | ||
|
||
|
||
def test_case_3(): | ||
pytorch_code = textwrap.dedent( | ||
""" | ||
import torch | ||
import torch.nn as nn | ||
input = torch.tensor([[0.2837, 0.0297, 0.0355], | ||
[ 0.9112, 0.7526, 0.4061]]) | ||
target = torch.tensor([[1.,0.,1.],[0.,1.,0.]]) | ||
weight = torch.tensor([0.5,0.2,0.3]) | ||
loss = torch.nn.BCELoss(weight=weight,reduction='none') | ||
result = loss(input,target) | ||
""" | ||
) | ||
obj.run(pytorch_code, ["result"]) | ||
|
||
|
||
def test_case_4(): | ||
pytorch_code = textwrap.dedent( | ||
""" | ||
import torch | ||
import torch.nn as nn | ||
input = torch.tensor([[0.2837, 0.0297, 0.0355], | ||
[ 0.9112, 0.7526, 0.4061]]) | ||
target = torch.tensor([[1.,0.,1.],[0.,1.,0.]]) | ||
weight = torch.tensor([0.5,0.2,0.3]) | ||
loss = torch.nn.BCELoss(weight=weight,reduction='mean') | ||
result = loss(input,target) | ||
""" | ||
) | ||
obj.run(pytorch_code, ["result"]) | ||
|
||
|
||
def test_case_5(): | ||
pytorch_code = textwrap.dedent( | ||
""" | ||
import torch | ||
import torch.nn as nn | ||
input = torch.tensor([[0.2837, 0.0297, 0.0355], | ||
[ 0.9112, 0.7526, 0.4061]]) | ||
target = torch.tensor([[1.,0.,1.],[0.,1.,0.]]) | ||
weight = torch.tensor([0.5,0.2,0.3]) | ||
loss = torch.nn.BCELoss(weight=weight,reduction='sum') | ||
result = loss(input,target) | ||
""" | ||
) | ||
obj.run(pytorch_code, ["result"]) | ||
|
||
|
||
def test_case_6(): | ||
pytorch_code = textwrap.dedent( | ||
""" | ||
import torch | ||
import torch.nn as nn | ||
input = torch.tensor([[0.2837, 0.0297, 0.0355], | ||
[ 0.9112, 0.7526, 0.4061]]) | ||
target = torch.tensor([[1.,0.,1.],[0.,1.,0.]]) | ||
weight = torch.tensor([0.5,0.2,0.3]) | ||
loss = torch.nn.BCELoss(weight=weight,reduce=True) | ||
result = loss(input,target) | ||
""" | ||
) | ||
obj.run(pytorch_code, ["result"]) | ||
|
||
|
||
def test_case_7(): | ||
pytorch_code = textwrap.dedent( | ||
""" | ||
import torch | ||
import torch.nn as nn | ||
input = torch.tensor([[0.2837, 0.0297, 0.0355], | ||
[ 0.9112, 0.7526, 0.4061]]) | ||
target = torch.tensor([[1.,0.,1.],[0.,1.,0.]]) | ||
weight = torch.tensor([0.5,0.2,0.3]) | ||
loss = torch.nn.BCELoss(weight=weight,reduce=False) | ||
result = loss(input,target) | ||
""" | ||
) | ||
obj.run(pytorch_code, ["result"]) | ||
|
||
|
||
def test_case_8(): | ||
pytorch_code = textwrap.dedent( | ||
""" | ||
import torch | ||
import torch.nn as nn | ||
input = torch.tensor([[0.2837, 0.0297, 0.0355], | ||
[ 0.9112, 0.7526, 0.4061]]) | ||
target = torch.tensor([[1.,0.,1.],[0.,1.,0.]]) | ||
loss = torch.nn.BCELoss() | ||
result = loss(input,target) | ||
""" | ||
) | ||
obj.run(pytorch_code, ["result"]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个改错了吧
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修正