Skip to content

Commit f39c71b

Browse files
authored
[SYCLomatic] Update Out field as optional for class type of user-defined rule if method/field are used (#2758)
Signed-off-by: Jiang, Zhiwei <zhiwei.jiang@intel.com>
1 parent f7b1de5 commit f39c71b

File tree

8 files changed

+46
-20
lines changed

8 files changed

+46
-20
lines changed

clang/lib/DPCT/MigrateScript/MigrateCmakeScript.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -599,7 +599,9 @@ void doCmakeScriptMigration(const clang::tooling::UnifiedPath &InRoot,
599599
}
600600

601601
void registerCmakeMigrationRule(MetaRuleObject &R) {
602-
auto PR = MetaRuleObject::PatternRewriter(R.In, R.Out, R.Subrules,
602+
if (!R.Out.has_value())
603+
return;
604+
auto PR = MetaRuleObject::PatternRewriter(R.In, R.Out.value(), R.Subrules,
603605
R.MatchMode, R.Warning, R.RuleId,
604606
R.BuildScriptSyntax, R.Priority);
605607
auto Iter = CmakeBuildInRules.find(PR.BuildScriptSyntax);

clang/lib/DPCT/MigrateScript/MigratePythonBuildScript.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,9 @@ void doPythonBuildScriptMigration(const clang::tooling::UnifiedPath &InRoot,
9191
}
9292

9393
void registerPythonMigrationRule(MetaRuleObject &R) {
94-
auto PR = MetaRuleObject::PatternRewriter(R.In, R.Out, R.Subrules,
94+
if (!R.Out.has_value())
95+
return;
96+
auto PR = MetaRuleObject::PatternRewriter(R.In, R.Out.value(), R.Subrules,
9597
R.MatchMode, R.Warning, R.RuleId,
9698
R.BuildScriptSyntax, R.Priority);
9799

clang/lib/DPCT/RuleInfra/CallExprRewriter.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1735,7 +1735,8 @@ class UserDefinedRewriterFactory : public CallExprRewriterFactoryBase {
17351735

17361736
public:
17371737
UserDefinedRewriterFactory(MetaRuleObject &R)
1738-
: OutStr(R.Out), Includes(R.Includes), RuleAttributes(R.RuleAttributes) {
1738+
: OutStr(R.Out.value()), Includes(R.Includes),
1739+
RuleAttributes(R.RuleAttributes) {
17391740
Priority = R.Priority;
17401741
OB.Kind = OutputBuilder::Kind::Top;
17411742
OB.RuleName = R.RuleId;

clang/lib/DPCT/RulesInclude/InclusionHeaders.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,8 @@ getUserDefinedHeader(const std::string &FileName) {
5555
for (auto &Header : Rule.Includes) {
5656
PrintHeader(Header);
5757
}
58-
if (!Rule.Out.empty())
59-
PrintHeader(Rule.Out);
58+
if (Rule.Out.has_value() && !Rule.Out.value().empty())
59+
PrintHeader(Rule.Out.value());
6060
OS << Rule.Postfix;
6161
return std::make_pair(ReplHeaderStr, Rule.Priority);
6262
}

clang/lib/DPCT/UserDefinedRules/UserDefinedRules.cpp

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -48,27 +48,30 @@ void registerMigrationRule(const std::string &Name, Functor &&F) {
4848
}
4949

5050
void registerMacroRule(MetaRuleObject &R) {
51+
if (!R.Out.has_value())
52+
return;
5153
auto It = MapNames::MacroRuleMap.find(R.In);
5254
if (It != MapNames::MacroRuleMap.end()) {
5355
if (It->second.Priority > R.Priority) {
5456
It->second.Id = R.RuleId;
5557
It->second.Priority = R.Priority;
5658
It->second.In = R.In;
57-
It->second.Out = R.Out;
59+
It->second.Out = R.Out.value();
5860
It->second.HelperFeature =
5961
clang::dpct::HelperFeatureEnum::none;
6062
It->second.Includes = R.Includes;
6163
}
6264
} else {
6365
MapNames::MacroRuleMap.emplace(
6466
R.In,
65-
MacroMigrationRule(R.RuleId, R.Priority, R.In, R.Out,
66-
clang::dpct::HelperFeatureEnum::none,
67-
R.Includes));
67+
MacroMigrationRule(R.RuleId, R.Priority, R.In, R.Out.value(),
68+
clang::dpct::HelperFeatureEnum::none, R.Includes));
6869
}
6970
}
7071

7172
void registerAPIRule(MetaRuleObject &R) {
73+
if (!R.Out.has_value())
74+
return;
7275
using namespace clang::dpct;
7376
// register rule
7477
registerMigrationRule(
@@ -119,6 +122,8 @@ void registerAPIRule(MetaRuleObject &R) {
119122
}
120123

121124
void registerHeaderRule(MetaRuleObject &R) {
125+
if (!R.Out.has_value())
126+
return;
122127
auto It = MapNames::HeaderRuleMap.find(R.In);
123128
if (It != MapNames::HeaderRuleMap.end()) {
124129
if (It->second.Priority > R.Priority) {
@@ -130,11 +135,13 @@ void registerHeaderRule(MetaRuleObject &R) {
130135
}
131136

132137
void registerTypeRule(MetaRuleObject &R) {
138+
if (!R.Out.has_value())
139+
return;
133140
std::shared_ptr TOB = std::make_shared<TypeOutputBuilder>();
134141
TOB->Kind = TypeOutputBuilder::Kind::Top;
135142
TOB->RuleName = R.RuleId;
136143
TOB->RuleFile = R.RuleFile;
137-
TOB->parse(R.Out);
144+
TOB->parse(R.Out.value());
138145

139146
if (R.RuleAttributes.NumOfTemplateArgs != -1) {
140147
dpct::TypeMatchingDesc TMD =
@@ -155,7 +162,7 @@ void registerTypeRule(MetaRuleObject &R) {
155162
auto It = MapNames::TypeNamesMap.find(R.In);
156163
if (It != MapNames::TypeNamesMap.end()) {
157164
if (It->second->Priority > R.Priority) {
158-
It->second->NewName = R.Out;
165+
It->second->NewName = R.Out.value();
159166
It->second->Priority = R.Priority;
160167
It->second->RequestFeature =
161168
clang::dpct::HelperFeatureEnum::none;
@@ -167,7 +174,7 @@ void registerTypeRule(MetaRuleObject &R) {
167174
return std::make_unique<clang::dpct::UserDefinedTypeRule>(In);
168175
});
169176
auto RulePtr = std::make_shared<TypeNameRule>(
170-
R.Out, clang::dpct::HelperFeatureEnum::none, R.Priority);
177+
R.Out.value(), clang::dpct::HelperFeatureEnum::none, R.Priority);
171178
RulePtr->Includes.insert(RulePtr->Includes.end(), R.Includes.begin(),
172179
R.Includes.end());
173180
MapNames::TypeNamesMap.emplace(R.In, RulePtr);
@@ -176,7 +183,8 @@ void registerTypeRule(MetaRuleObject &R) {
176183

177184
void registerClassRule(MetaRuleObject &R) {
178185
// register class name migration rule
179-
registerTypeRule(R);
186+
if (R.Out.has_value())
187+
registerTypeRule(R);
180188
// register all field rules
181189
for (auto ItField = R.Fields.begin(); ItField != R.Fields.end(); ItField++) {
182190
std::string BaseAndFieldName = R.In + "." + (*ItField)->In;
@@ -250,11 +258,13 @@ void registerClassRule(MetaRuleObject &R) {
250258
}
251259

252260
void registerEnumRule(MetaRuleObject &R) {
261+
if (!R.Out.has_value())
262+
return;
253263
auto It = MapNames::EnumNamesMap.find(R.In);
254264
if (It != MapNames::EnumNamesMap.end()) {
255265
if (It->second->Priority > R.Priority) {
256266
It->second->Priority = R.Priority;
257-
It->second->NewName = R.Out;
267+
It->second->NewName = R.Out.value();
258268
It->second->RequestFeature =
259269
clang::dpct::HelperFeatureEnum::none;
260270
It->second->Includes.insert(It->second->Includes.end(),
@@ -268,7 +278,7 @@ void registerEnumRule(MetaRuleObject &R) {
268278
return std::make_unique<clang::dpct::UserDefinedEnumRule>(Enum);
269279
});
270280
auto RulePtr = std::make_shared<EnumNameRule>(
271-
R.Out, clang::dpct::HelperFeatureEnum::none, R.Priority);
281+
R.Out.value(), clang::dpct::HelperFeatureEnum::none, R.Priority);
272282
RulePtr->Includes.insert(RulePtr->Includes.end(), R.Includes.begin(),
273283
R.Includes.end());
274284
MapNames::EnumNamesMap.emplace(
@@ -277,17 +287,23 @@ void registerEnumRule(MetaRuleObject &R) {
277287
}
278288

279289
void deregisterAPIRule(MetaRuleObject &R) {
290+
if (!R.Out.has_value())
291+
return;
280292
using namespace clang::dpct;
281293
CallExprRewriterFactoryBase::RewriterMap->erase(R.In);
282294
}
283295

284296
void registerPatternRewriterRule(MetaRuleObject &R) {
297+
if (!R.Out.has_value())
298+
return;
285299
MapNames::PatternRewriters.emplace_back(MetaRuleObject::PatternRewriter(
286-
R.In, R.Out, R.Subrules, R.MatchMode, R.Warning, R.RuleId,
300+
R.In, R.Out.value(), R.Subrules, R.MatchMode, R.Warning, R.RuleId,
287301
R.BuildScriptSyntax, R.Priority));
288302
}
289303

290304
void registerHelperFunctionRule(MetaRuleObject &R) {
305+
if (!R.Out.has_value())
306+
return;
291307
static const std::unordered_map<std::string, dpct::HelperFuncCatalog>
292308
String2HelperFuncCatalogMap{
293309
{"get_default_queue", dpct::HelperFuncCatalog::GetDefaultQueue},
@@ -300,7 +316,7 @@ void registerHelperFunctionRule(MetaRuleObject &R) {
300316
// This map is inited here.
301317
// It saves the customized string which used for each kind of helper
302318
// function call in the migrated code.
303-
MapNames::CustomHelperFunctionMap.insert({Iter->second, R.Out});
319+
MapNames::CustomHelperFunctionMap.insert({Iter->second, R.Out.value()});
304320
dpct::DpctGlobalInfo::setUsingDRYPattern(false);
305321
dpct::DpctGlobalInfo::getCustomHelperFunctionAddtionalIncludes().insert(
306322
R.Includes.begin(), R.Includes.end());

clang/lib/DPCT/UserDefinedRules/UserDefinedRules.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ class MetaRuleObject {
130130
std::string BuildScriptSyntax;
131131
RuleKind Kind;
132132
std::string In;
133-
std::string Out;
133+
std::optional<std::string> Out = std::nullopt;
134134
std::string EnumName;
135135
std::string Prefix;
136136
std::string Postfix;
@@ -319,7 +319,7 @@ template <> struct llvm::yaml::MappingTraits<std::shared_ptr<MetaRuleObject>> {
319319
Io.mapOptional("CmakeSyntax", Doc->BuildScriptSyntax);
320320
Io.mapOptional("PythonSyntax", Doc->BuildScriptSyntax);
321321
Io.mapRequired("In", Doc->In);
322-
Io.mapRequired("Out", Doc->Out);
322+
Io.mapOptional("Out", Doc->Out);
323323
Io.mapOptional("Includes", Doc->Includes);
324324
Io.mapOptional("Fields", Doc->Fields);
325325
Io.mapOptional("Methods", Doc->Methods);

clang/test/dpct/pytorch/torch.cu

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,9 @@ void foo(torch::Tensor x) {
2525
// CHECK: MY_CHECK(x.is_xpu(), "x must reside on device");
2626
MY_CHECK(x.is_cuda(), "x must reside on device");
2727
}
28+
29+
// void foo2(at::Tensor x) {
30+
void foo2(at::Tensor x) {
31+
// CHECK: MY_CHECK(x.is_xpu(), "x must reside on device");
32+
MY_CHECK(x.is_cuda(), "x must reside on device");
33+
}

clang/tools/dpct/extensions/pytorch_api_rules/pytorch_api.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,6 @@
156156
Kind: Class
157157
Priority: Takeover
158158
In: at::Tensor # The underlying type of torch::Tensor is at::Tensor
159-
Out: torch::Tensor
160159
Methods:
161160
- In: is_cuda
162161
Out: $method_base is_xpu()

0 commit comments

Comments
 (0)