-
Notifications
You must be signed in to change notification settings - Fork 61
add some loss api #109
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add some loss api #109
Conversation
paconvert/api_matcher.py
Outdated
@@ -3594,6 +3594,55 @@ def generate_code(self, kwargs): | |||
return code | |||
|
|||
|
|||
class MseLossMatcher(BaseMatcher): | |||
def generate_code(self, kwargs): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
感觉这一块的重复度很高,是否可以统一成一个Matcher
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
这个文档的链接不对,不是loss的 |
PR描述里,docs直接放PR链接,不要写成软连接形式 |
paconvert/api_mapping.json
Outdated
"reduction" | ||
] | ||
}, | ||
"torch.nn.Unfold": { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个可以用genericmatcher吧,改成那个吧
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
kernel_size 参数 pytorch支持tuple,paddle不支持,改为genericmatcher遇到tuple会报错
paconvert/api_mapping.json
Outdated
"stride": "strides" | ||
} | ||
}, | ||
"torch.nn.functional.unfold": { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个可以用genericmatcher吧,改成那个吧
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同torch.nn.Unfold
paconvert/api_matcher.py
Outdated
for key in list(kwargs_change.keys()): | ||
if key in kwargs: | ||
if "input" not in key: | ||
if "(" in kwargs[key]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
直接判断 if isinstance(kwargs[key] , ast.Tuple): 吧
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个在generate_code直接这样判断似乎不起作用
paconvert/api_matcher.py
Outdated
kwargs[kwargs_change[key]] = kwargs[key] | ||
kwargs.pop(key) | ||
|
||
if "paddings" not in kwargs: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个是默认值就不用单独设置了
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
paconvert/api_matcher.py
Outdated
kwargs_change = self.api_mapping["kwargs_change"] | ||
for key in list(kwargs_change.keys()): | ||
if key in kwargs: | ||
if "input" not in key: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个由于不可能是tuple,也不用单独判断
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
paconvert/api_matcher.py
Outdated
@@ -3359,6 +3359,30 @@ def generate_code(self, kwargs): | |||
return GenericMatcher.generate_code(self, kwargs) | |||
|
|||
|
|||
class UnfoldMatcher(BaseMatcher): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个可以起一个通用的名字,这个主要功能是把tuple转成list:
可以叫Tuple2ListMatcher
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
@@ -3193,6 +3193,30 @@ def generate_code(self, kwargs): | |||
return GenericMatcher.generate_code(self, kwargs) | |||
|
|||
|
|||
class UnfoldMatcher(BaseMatcher): | |||
def generate_code(self, kwargs): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
逻辑可以写成对每个kwargs遍历,判断是否kwargs,每个分支里再判断是否list,一共4个分支。用new_kwargs来接收kwargs,不然参数顺序会改变,导致代码风格不太好
for k in list(kwargs.keys()):
if kwargs_change:
if tuple:
else:
else:
if tuple:
else:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
paconvert/api_mapping.json
Outdated
"device", | ||
"requires_grad", | ||
"memory_format" | ||
"index" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个改错了吧
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修正
paconvert/api_matcher.py
Outdated
for k in list(kwargs.keys()): | ||
if k in kwargs_change: | ||
if "(" in kwargs[k] and isinstance(ast.literal_eval(kwargs[k]), tuple): | ||
new_kwargs[kwargs_change[k]] = list(ast.literal_eval(kwargs[k])) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
判断前面这个就可以,直接 'list({})'.format(kwargs[k]) 吧,尽量不要用eval执行这些逻辑,容易存在隐患
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已更新判断逻辑
paconvert/api_matcher.py
Outdated
new_kwargs[kwargs_change[k]] = kwargs[k] | ||
else: | ||
if "(" in kwargs[k] and isinstance(ast.literal_eval(kwargs[k]), tuple): | ||
new_kwargs[k] = list(ast.literal_eval(kwargs[k])) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
obj.run(pytorch_code, ["result"]) | ||
|
||
|
||
def _test_case_7(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个什么原因不支持,需要写明白
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
obj.run(pytorch_code, ["result"]) | ||
|
||
|
||
def _test_case_7(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个什么原因不支持,需要写明白
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
PR Docs
PaddlePaddle/docs#5928
PR APIs