Skip to content

Commit 54351a5

Browse files
[SYCLomatic][Bug][Python Migration] Added support for parsing Python syntax inside blocks like (), [] & {} (#2670)
1 parent 882a371 commit 54351a5

File tree

6 files changed

+149
-37
lines changed

6 files changed

+149
-37
lines changed

clang/lib/DPCT/MigrateScript/MigratePythonBuildScript.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,8 @@ applyPythonMigrationRules(const clang::tooling::UnifiedPath InRoot,
7373
for (const auto &PythonSyntaxEntry : PythonBuildInRules) {
7474
const auto &PR = PythonSyntaxEntry.second;
7575
if (!PR.In.empty() || !PR.Out.empty()) {
76-
Buffer = applyPatternRewriter(PR, Buffer);
76+
Buffer = applyPatternRewriter(PR, Buffer, Entry.first.getPath().str(),
77+
"", OutRoot);
7778
}
7879
}
7980
}

clang/lib/DPCT/UserDefinedRules/PatternRewriter.cpp

Lines changed: 79 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@
99
#include "UserDefinedRules/PatternRewriter.h"
1010
#include "AnalysisInfo.h"
1111
#include "Diagnostics/Diagnostics.h"
12+
#include "FileGenerator/GenFiles.h"
1213
#include "MigrateScript/MigrateCmakeScript.h"
1314
#include "MigrateScript/MigratePythonBuildScript.h"
1415
#include "UserDefinedRules/UserDefinedRules.h"
15-
#include "FileGenerator/GenFiles.h"
1616

1717
#include "llvm/ADT/StringRef.h"
1818
#include "llvm/Support/Path.h"
@@ -25,7 +25,6 @@
2525
#include <variant>
2626
#include <vector>
2727

28-
2928
namespace clang {
3029
namespace dpct {
3130

@@ -62,10 +61,28 @@ static bool isWhitespace(char Character) {
6261

6362
// Negation of isWhitespace(): returns true when Character is not classified
// as whitespace by isWhitespace().
static bool isNotWhitespace(char Character) { return !isWhitespace(Character); }
6463

64+
// Returns true when Character is one of the opening delimiters '{', '[' or
// '('.
static bool isLeftDelimiter(char Character) {
65+
return Character == '{' || Character == '[' || Character == '(';
66+
}
67+
6568
// Returns true when Character is one of the closing delimiters '}', ']' or
// ')'.
static bool isRightDelimiter(char Character) {
6669
return Character == '}' || Character == ']' || Character == ')';
6770
}
6871

72+
// Maps an opening delimiter to its matching closing delimiter:
// '{' -> '}', '[' -> ']', '(' -> ')'.
// Returns '\0' when Character is not one of those three opening delimiters.
static char getRightDelimiter(char Character) {
73+
// Default result for any character that is not an opening delimiter.
char rightChar = '\0';
74+
75+
if (Character == '{') {
76+
rightChar = '}';
77+
} else if (Character == '[') {
78+
rightChar = ']';
79+
} else if (Character == '(') {
80+
rightChar = ')';
81+
}
82+
83+
return rightChar;
84+
}
85+
6986
static int detectIndentation(const std::string &Input, int Start) {
7087
int Indentation = 0;
7188
int Index = Start - 1;
@@ -239,11 +256,11 @@ findMatch(const MatchPattern &Pattern, const std::string &Input,
239256

240257
static int parseCodeElement(const MatchPattern &Suffix,
241258
const std::string &Input, const int Start,
242-
RuleMatchMode Mode);
259+
RuleMatchMode Mode, bool IsInsideBlock = false);
243260

244261
static int parseBlock(char LeftDelimiter, char RightDelimiter,
245262
const std::string &Input, const int Start,
246-
RuleMatchMode Mode) {
263+
RuleMatchMode Mode, const MatchPattern &Suffix = {}) {
247264
const int Size = Input.size();
248265
int Index = Start;
249266

@@ -252,25 +269,21 @@ static int parseBlock(char LeftDelimiter, char RightDelimiter,
252269
}
253270
Index++;
254271

255-
Index = parseCodeElement({}, Input, Index, Mode);
256-
if (Index == -1) {
257-
return -1;
258-
}
272+
Index = parseCodeElement(Suffix, Input, Index, Mode, true);
259273

260-
if (Index >= Size || Input[Index] != RightDelimiter) {
274+
if (Index == -1 || Index >= Size) {
261275
return -1;
262276
}
263-
Index++;
277+
264278
return Index;
265279
}
266280

267281
static int parseCodeElement(const MatchPattern &Suffix,
268282
const std::string &Input, const int Start,
269-
RuleMatchMode Mode) {
283+
RuleMatchMode Mode, bool IsInsideBlock) {
270284
int Index = Start;
271285
const int Size = Input.size();
272286
while (Index >= 0 && Index < Size) {
273-
274287
if (SrcFileType == SourceFileType::SFT_CMakeScript ||
275288
SrcFileType == SourceFileType::SFT_PySetupScript) {
276289
if (Input[Index] == '#') {
@@ -281,11 +294,10 @@ static int parseCodeElement(const MatchPattern &Suffix,
281294
}
282295

283296
const auto Character = Input[Index];
284-
if (SrcFileType != SourceFileType::SFT_PySetupScript) {
285-
if (Suffix.size() == 0 && Character == '"') {
286-
return Index;
287-
}
297+
if (Suffix.size() == 0 && Character == '"') {
298+
return Index;
288299
}
300+
289301
if (Suffix.size() > 0) {
290302
std::optional<MatchResult> SuffixMatch;
291303

@@ -295,24 +307,29 @@ static int parseCodeElement(const MatchPattern &Suffix,
295307
return Index;
296308
}
297309

298-
if (isRightDelimiter(Character) || Index == Size - 1) {
310+
if (Index == Size - 1) {
299311
return -1;
300312
}
301-
}
302313

303-
if (Character == '{') {
304-
Index = parseBlock('{', '}', Input, Index, Mode);
305-
continue;
314+
if (isRightDelimiter(Character)) {
315+
return IsInsideBlock ? Index : -1;
316+
}
306317
}
307318

308-
if (Character == '[') {
309-
Index = parseBlock('[', ']', Input, Index, Mode);
310-
continue;
311-
}
319+
if (isLeftDelimiter(Character)) {
320+
char RightDelimiter = getRightDelimiter(Character);
312321

313-
if (Character == '(') {
314-
Index = parseBlock('(', ')', Input, Index, Mode);
315-
continue;
322+
if (SrcFileType == SourceFileType::SFT_PySetupScript) {
323+
Index =
324+
parseBlock(Character, RightDelimiter, Input, Index, Mode, Suffix);
325+
} else {
326+
Index = parseBlock(Character, RightDelimiter, Input, Index, Mode);
327+
}
328+
if (Index != -1 && isRightDelimiter(Input[Index])) {
329+
Index++;
330+
continue;
331+
} else
332+
return Index;
316333
}
317334

318335
if (isRightDelimiter(Input[Index])) {
@@ -394,6 +411,7 @@ static int parseCodeElement(const MatchPattern &Suffix,
394411

395412
Index++;
396413
}
414+
397415
return Suffix.size() == 0 ? Index : -1;
398416
}
399417

@@ -402,24 +420,43 @@ static int parseCodeElement(const MatchPattern &Suffix,
402420
// Returns true when Char can be part of an identifier: an ASCII letter, an
// ASCII digit or '_'.
// This hunk changes the handling of '-': the removed condition accepted it
// unconditionally, while the added branch accepts it only when SrcFileType is
// SourceFileType::SFT_CMakeScript.
static bool isIdentifiedChar(char Char) {
403421

404422
if ((Char >= 'a' && Char <= 'z') || (Char >= 'A' && Char <= 'Z') ||
405-
(Char >= '0' && Char <= '9') || (Char == '_') || (Char == '-')) {
423+
(Char >= '0' && Char <= '9') || (Char == '_')) {
406424
return true;
425+
} else if (SrcFileType == SourceFileType::SFT_CMakeScript) {
426+
if (Char == '-')
427+
return true;
407428
}
408429

409430
return false;
410431
}
411432

433+
// Returns true when Char may appear in the part of a file name/path that
// precedes the extension: any character accepted by isIdentifiedChar(), plus
// '.', '/' and '\\'.
static bool isValidFilePrefix(char Char) {
434+
return isIdentifiedChar(Char) || Char == '.' || Char == '/' || Char == '\\';
435+
}
436+
437+
// Returns true when Char can directly follow a complete file name, i.e. it is
// neither a character accepted by isIdentifiedChar() nor '.'. A false result
// means the text continues past the candidate extension.
static bool isValidFilePostfix(char Char) {
438+
return !(isIdentifiedChar(Char) || Char == '.');
439+
}
440+
412441
static void applyExtenstionNameChange(
413442
const std::string &Input, size_t Next,
414443
std::unordered_map<std::string, std::string> &Bindings,
415444
const std::string &FileName, const clang::tooling::UnifiedPath &OutRoot,
416445
std::string ExtensionType) {
446+
447+
// Check for valid postfix for a file name
448+
if (!isValidFilePostfix(Input[Next + ExtensionType.length() + 1])) {
449+
Bindings["rewrite_extention_name"] = std::move(ExtensionType);
450+
return;
451+
}
452+
417453
size_t Pos = Next - 1;
418-
for (; Pos > 0 && !isWhitespace(Input[Pos]); Pos--) {
454+
// Find the starting position of the file name
455+
for (; Pos > 0 && isValidFilePrefix(Input[Pos]); Pos--) {
419456
}
420-
Pos = Pos == 0 ? 0 : Pos + 1;
421-
if (Input[Pos] == '"' || Input[Pos] == '\'')
422-
Pos += 1;
457+
if (!isValidFilePrefix(Input[Pos]))
458+
Pos++;
459+
423460
std::string SrcFile = Input.substr(Pos, Next + ExtensionType.length() +
424461
1 /*strlen of "."*/ - Pos);
425462
bool HasCudaSyntax = false;
@@ -480,6 +517,14 @@ static void applyExtenstionNameChange(
480517
HasCudaSyntax = true;
481518
break;
482519
}
520+
} else if (llvm::sys::path::filename(FileName).ends_with(".py")) {
521+
llvm::SmallString<512> _SrcFile(SrcFile);
522+
llvm::sys::path::native(_SrcFile);
523+
524+
if (llvm::StringRef(File).ends_with(_SrcFile)) {
525+
HasCudaSyntax = true;
526+
break;
527+
}
483528
} else {
484529
// For other module files (e.g., .cmake files), just check the
485530
// file names.
@@ -883,4 +928,4 @@ std::string applyPatternRewriter(const MetaRuleObject::PatternRewriter &PP,
883928
}
884929

885930
} // namespace dpct
886-
} // namespace clang
931+
} // namespace clang

clang/test/dpct/python_migration/case_007/MainSourceFiles.yaml

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,35 @@ Replacements:
1919
InitStr: ''
2020
NewHostVarName: ''
2121
BlockLevelFormatFlag: false
22+
- FilePath: '/path/to/src/bar.cpp'
23+
Offset: 0
24+
Length: 0
25+
ReplacementText: ''
26+
ConstantFlag: ''
27+
ConstantOffset: 0
28+
InitStr: ''
29+
NewHostVarName: ''
30+
BlockLevelFormatFlag: false
31+
- FilePath: '/path/to/dst/bar.cpp'
32+
Offset: 0
33+
Length: 0
34+
ReplacementText: ''
35+
ConstantFlag: ''
36+
ConstantOffset: 0
37+
InitStr: ''
38+
NewHostVarName: ''
39+
BlockLevelFormatFlag: false
2240

2341
MainSourceFilesDigest:
2442
- MainSourceFile: '/path/to/src/foo.cpp'
2543
Digest: e2636fb8d174ac319083b0306294d3bd
2644
HasCUDASyntax: true
45+
- MainSourceFile: '/path/to/src/bar.cpp'
46+
Digest: e2636fb8d174ac319083b0306294d3bd
47+
HasCUDASyntax: true
48+
- MainSourceFile: '/path/to/dst/bar.cpp'
49+
Digest: e2636fb8d174ac319083b0306294d3bd
50+
HasCUDASyntax: false
2751
- MainSourceFile: '/path/to/src/baz.cpp'
2852
Digest: 991d7e4825fc597205bf68f2eda27acd
2953
HasCUDASyntax: false
Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,13 @@
11
# baz.cpp is a C++ file
22
out = func("bar.dp.cpp", "baz.cpp")
3+
34
# foo.cpp is a C++ file with CUDA syntax
45
out = func("foo.cpp.dp.cpp", "bar.dp.cpp")
5-
out = func(["foo.cpp.dp.cpp", "bar.dp.cpp"])
6+
7+
# src/bar.cpp is a C++ file with CUDA syntax
8+
# dst/bar.cpp is a C++ file
9+
out = func("src/bar.cpp.dp.cpp", "dst/bar.cpp")
10+
11+
("foo.cpp.dp.cpp", ("bar.dp.cpp", ("baz.cpp"), ("foo.cpp.dp.cpp", ), ["foo.cpp.dp.cpp"], ("foo.cpp.dp.cpp"), {"foo.cpp.dp.cpp"}, "foo.cpp.dp.cpp"))
12+
["foo.cpp.dp.cpp", ["bar.dp.cpp", ["baz.cpp"], ["foo.cpp.dp.cpp", ], ["foo.cpp.dp.cpp"], ("foo.cpp.dp.cpp"), {"foo.cpp.dp.cpp"}, "foo.cpp.dp.cpp"]]
13+
{"foo.cpp.dp.cpp", {"bar.dp.cpp", {"baz.cpp"}, {"foo.cpp.dp.cpp", }, ["foo.cpp.dp.cpp"], ("foo.cpp.dp.cpp"), {"foo.cpp.dp.cpp"}, "foo.cpp.dp.cpp"}}
Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,13 @@
11
# baz.cpp is a C++ file
22
out = func("bar.cu", "baz.cpp")
3+
34
# foo.cpp is a C++ file with CUDA syntax
45
out = func("foo.cpp", "bar.cu")
5-
out = func(["foo.cpp", "bar.cu"])
6+
7+
# src/bar.cpp is a C++ file with CUDA syntax
8+
# dst/bar.cpp is a C++ file
9+
out = func("src/bar.cpp", "dst/bar.cpp")
10+
11+
("foo.cpp", ("bar.cu", ("baz.cpp"), ("foo.cpp", ), ["foo.cpp"], ("foo.cpp"), {"foo.cpp"}, "foo.cpp"))
12+
["foo.cpp", ["bar.cu", ["baz.cpp"], ["foo.cpp", ], ["foo.cpp"], ("foo.cpp"), {"foo.cpp"}, "foo.cpp"]]
13+
{"foo.cpp", {"bar.cu", {"baz.cpp"}, {"foo.cpp", }, ["foo.cpp"], ("foo.cpp"), {"foo.cpp"}, "foo.cpp"}}

clang/tools/dpct/extensions/python_rules/python_build_script_migration_rule_pytorch_common.yaml

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,19 @@
3434
In: ${arg}.cpp
3535
Out: ${arg}.${rewrite_extention_name}
3636

37+
# Fallback rule: partially matches a brace-delimited block "{...}" and applies
# the "value" subrule to its content, rewriting "<arg>.cpp" to
# "<arg>.<rewrite_extention_name>". The output keeps the surrounding braces.
- Rule: rule_cpp_file_in_dict
38+
Kind: PythonRule
39+
Priority: Fallback
40+
MatchMode: Partial
41+
PythonSyntax: cpp_file_in_dict
42+
In: "{${value}}"
43+
Out: "{${value}}"
44+
Subrules:
45+
value:
46+
MatchMode: Full
47+
In: ${arg}.cpp
48+
Out: ${arg}.${rewrite_extention_name}
49+
3750
- Rule: rule_cu_file_in_func
3851
Kind: PythonRule
3952
Priority: Fallback
@@ -60,6 +73,19 @@
6073
In: ${arg}.cu
6174
Out: ${arg}.${rewrite_extention_name}
6275

76+
# Fallback rule: partially matches a brace-delimited block "{...}" and applies
# the "value" subrule to its content, rewriting "<arg>.cu" to
# "<arg>.<rewrite_extention_name>". The output keeps the surrounding braces.
- Rule: rule_cu_file_in_dict
77+
Kind: PythonRule
78+
Priority: Fallback
79+
MatchMode: Partial
80+
PythonSyntax: cu_file_in_dict
81+
In: "{${value}}"
82+
Out: "{${value}}"
83+
Subrules:
84+
value:
85+
MatchMode: Full
86+
In: ${arg}.cu
87+
Out: ${arg}.${rewrite_extention_name}
88+
6389
# cuda import rules
6490
- Rule: rule_from_torch_import_cuda
6591
Kind: PythonRule

0 commit comments

Comments
 (0)