Skip to content

add some loss api #109

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Jul 5, 2023
Merged

add some loss api #109

merged 12 commits into from
Jul 5, 2023

Conversation

LokeZhou
Copy link
Contributor

@LokeZhou LokeZhou commented Jun 19, 2023

PR Docs

PaddlePaddle/docs#5928

PR APIs

torch.nn.MSELoss
torch.nn.functional.binary_cross_entropy
torch.nn.BCELoss
torch.nn.L1Loss
torch.nn.Unfold
torch.nn.functional.unfold

@@ -3594,6 +3594,55 @@ def generate_code(self, kwargs):
return code


class MseLossMatcher(BaseMatcher):
def generate_code(self, kwargs):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

感觉这一块的重复度很高,是否可以统一成一个Matcher

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

@zhwesky2010
Copy link
Collaborator

PaddlePaddle/docs#5928

PR Docs

PaddlePaddle/docs#5928

PR APIs

torch.nn.MSELoss
torch.nn.functional.binary_cross_entropy
torch.nn.BCELoss
torch.nn.L1Loss
torch.nn.Unfold
torch.nn.functional.unfold

这个文档的链接不对,不是loss的

@LokeZhou LokeZhou requested a review from zhwesky2010 June 21, 2023 06:12
@zhwesky2010
Copy link
Collaborator

zhwesky2010 commented Jun 26, 2023

PR描述里,docs直接放PR链接,不要写成软连接形式

"reduction"
]
},
"torch.nn.Unfold": {
Copy link
Collaborator

@zhwesky2010 zhwesky2010 Jun 26, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个可以用genericmatcher吧,改成那个吧

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

kernel_size 参数 pytorch支持tuple,paddle不支持,改为genericmatcher遇到tuple会报错

"stride": "strides"
}
},
"torch.nn.functional.unfold": {
Copy link
Collaborator

@zhwesky2010 zhwesky2010 Jun 26, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个可以用genericmatcher吧,改成那个吧

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同torch.nn.Unfold

for key in list(kwargs_change.keys()):
if key in kwargs:
if "input" not in key:
if "(" in kwargs[key]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

直接判断 if isinstance(kwargs[key] , ast.Tuple): 吧

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个在generate_code直接这样判断似乎不起作用

kwargs[kwargs_change[key]] = kwargs[key]
kwargs.pop(key)

if "paddings" not in kwargs:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个是默认值就不用单独设置了

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

kwargs_change = self.api_mapping["kwargs_change"]
for key in list(kwargs_change.keys()):
if key in kwargs:
if "input" not in key:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个由于不可能是tuple,也不用单独判断

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

@@ -3359,6 +3359,30 @@ def generate_code(self, kwargs):
return GenericMatcher.generate_code(self, kwargs)


class UnfoldMatcher(BaseMatcher):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个可以起一个通用的名字,这个主要功能是把tuple转成list:
可以叫Tuple2ListMatcher

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

@@ -3193,6 +3193,30 @@ def generate_code(self, kwargs):
return GenericMatcher.generate_code(self, kwargs)


class UnfoldMatcher(BaseMatcher):
def generate_code(self, kwargs):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

逻辑可以写成对每个kwargs遍历,判断是否kwargs,每个分支里再判断是否list,一共4个分支。用new_kwargs来接收kwargs,不然参数顺序会改变,导致代码风格不太好

for k in list(kwargs.keys()):
    if kwargs_change:
          if tuple:
          else:
    else:
          if tuple:
          else:

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

"device",
"requires_grad",
"memory_format"
"index"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个改错了吧

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修正

for k in list(kwargs.keys()):
if k in kwargs_change:
if "(" in kwargs[k] and isinstance(ast.literal_eval(kwargs[k]), tuple):
new_kwargs[kwargs_change[k]] = list(ast.literal_eval(kwargs[k]))
Copy link
Collaborator

@zhwesky2010 zhwesky2010 Jul 4, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

判断前面这个就可以,直接 'list({})'.format(kwargs[k]) 吧,尽量不要用eval执行这些逻辑,容易存在隐患

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已更新判断逻辑

new_kwargs[kwargs_change[k]] = kwargs[k]
else:
if "(" in kwargs[k] and isinstance(ast.literal_eval(kwargs[k]), tuple):
new_kwargs[k] = list(ast.literal_eval(kwargs[k]))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

obj.run(pytorch_code, ["result"])


def _test_case_7():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个什么原因不支持,需要写明白

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

obj.run(pytorch_code, ["result"])


def _test_case_7():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个什么原因不支持,需要写明白

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

zhwesky2010
zhwesky2010 previously approved these changes Jul 5, 2023
Copy link
Collaborator

@zhwesky2010 zhwesky2010 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@zhwesky2010 zhwesky2010 merged commit 182bd89 into PaddlePaddle:master Jul 5, 2023
@LokeZhou LokeZhou deleted the loss branch July 5, 2023 08:43
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants