@@ -5903,7 +5903,6 @@ void KernelCallExpr::printSubmit(KernelPrinter &Printer) {
5903
5903
Printer << " *" << getEvent () << " = " ;
5904
5904
}
5905
5905
5906
- printStreamBase (Printer);
5907
5906
if (isDefaultStream ()) {
5908
5907
SubmitStmts.DefaultStreamFlag = true ;
5909
5908
}
@@ -5912,8 +5911,12 @@ void KernelCallExpr::printSubmit(KernelPrinter &Printer) {
5912
5911
SubmitStmts.ImplicitSyncFlag = true ;
5913
5912
}
5914
5913
if (SubmitStmts.empty ()) {
5914
+ if (ExecutionConfig.Properties .empty ()) {
5915
+ printStreamBase (Printer);
5916
+ }
5915
5917
printParallelFor (Printer, false );
5916
5918
} else {
5919
+ printStreamBase (Printer);
5917
5920
(Printer << " submit(" ).newLine ();
5918
5921
printSubmitLambda (Printer);
5919
5922
}
@@ -5945,12 +5948,20 @@ void KernelCallExpr::printParallelFor(KernelPrinter &Printer, bool IsInSubmit) {
5945
5948
}
5946
5949
}
5947
5950
}
5951
+ bool UseEnqueueFunctions = !ExecutionConfig.Properties .empty ();
5948
5952
if (IsInSubmit) {
5949
- Printer.indent () << " cgh." ;
5953
+ Printer.indent ();
5954
+ if (!UseEnqueueFunctions) {
5955
+ Printer << " cgh." ;
5956
+ }
5950
5957
}
5951
5958
if (!SubmitStmts.NdRangeList .empty () && DpctGlobalInfo::isCommentsEnabled ())
5952
5959
Printer.line (" // run the kernel within defined ND range" );
5953
- Printer << " parallel_for" ;
5960
+ if (UseEnqueueFunctions) {
5961
+ Printer << MapNames::getExpNamespace () << " nd_launch" ;
5962
+ } else {
5963
+ Printer << " parallel_for" ;
5964
+ }
5954
5965
if (DpctGlobalInfo::isSyclNamedLambda ()) {
5955
5966
Printer << " <dpct_kernel_name<class " << getName () << " _"
5956
5967
<< LocInfo.LocHash ;
@@ -5961,16 +5972,26 @@ void KernelCallExpr::printParallelFor(KernelPrinter &Printer, bool IsInSubmit) {
5961
5972
}
5962
5973
(Printer << " (" ).newLine ();
5963
5974
auto B = Printer.block ();
5975
+ std::unique_ptr<KernelPrinter::Block> LaunchConfigBlock;
5976
+ if (UseEnqueueFunctions) {
5977
+ (Printer.indent () << (IsInSubmit ? " cgh" : ExecutionConfig.Stream ) << " ," )
5978
+ .newLine ();
5979
+ }
5964
5980
static std::string CanIgnoreRangeStr3D =
5965
5981
DpctGlobalInfo::getCtadClass (MapNames::getClNamespace () + " range" , 3 ) +
5966
5982
" (1, 1, 1)" ;
5967
5983
static std::string CanIgnoreRangeStr1D =
5968
5984
DpctGlobalInfo::getCtadClass (MapNames::getClNamespace () + " range" , 1 ) +
5969
5985
" (1)" ;
5970
5986
if (ExecutionConfig.NdRange != " " ) {
5987
+ if (UseEnqueueFunctions) {
5988
+ Printer.line (MapNames::getExpNamespace () + " launch_config(" );
5989
+ LaunchConfigBlock = std::move (Printer.block ());
5990
+ }
5971
5991
Printer.line (ExecutionConfig.NdRange + " ," );
5972
- if (!ExecutionConfig.Properties .empty ()) {
5973
- Printer << ExecutionConfig.Properties << " , " ;
5992
+ if (UseEnqueueFunctions) {
5993
+ Printer.line (ExecutionConfig.Properties + " )," );
5994
+ LaunchConfigBlock.reset ();
5974
5995
}
5975
5996
Printer.line (" [=](" , MapNames::getClNamespace (), " nd_item<3> " ,
5976
5997
getItemName (), " )" , ExecutionConfig.SubGroupSize , " {" );
@@ -5980,6 +6001,10 @@ void KernelCallExpr::printParallelFor(KernelPrinter &Printer, bool IsInSubmit) {
5980
6001
MemVarMap::getHeadWithoutPathCompression (
5981
6002
&(getFuncInfo ()->getVarMap ()))
5982
6003
->Dim == 1 ) {
6004
+ if (UseEnqueueFunctions) {
6005
+ Printer.line (MapNames::getExpNamespace () + " launch_config(" );
6006
+ LaunchConfigBlock = std::move (Printer.block ());
6007
+ }
5983
6008
DpctGlobalInfo::printCtadClass (Printer.indent (),
5984
6009
MapNames::getClNamespace () + " nd_range" , 1 )
5985
6010
<< " (" ;
@@ -5994,12 +6019,17 @@ void KernelCallExpr::printParallelFor(KernelPrinter &Printer, bool IsInSubmit) {
5994
6019
Printer << " , " ;
5995
6020
Printer << ExecutionConfig.LocalSizeFor1D ;
5996
6021
(Printer << " ), " ).newLine ();
5997
- if (!ExecutionConfig.Properties .empty ()) {
5998
- Printer << ExecutionConfig.Properties << " , " ;
6022
+ if (UseEnqueueFunctions) {
6023
+ Printer.line (ExecutionConfig.Properties + " )," );
6024
+ LaunchConfigBlock.reset ();
5999
6025
}
6000
6026
Printer.line (" [=](" + MapNames::getClNamespace () + " nd_item<1> " ,
6001
6027
getItemName (), " )" , ExecutionConfig.SubGroupSize , " {" );
6002
6028
} else {
6029
+ if (UseEnqueueFunctions) {
6030
+ Printer.line (MapNames::getExpNamespace () + " launch_config(" );
6031
+ LaunchConfigBlock = std::move (Printer.block ());
6032
+ }
6003
6033
Printer.indent ();
6004
6034
Printer << MapNames::getClNamespace () + " nd_range<3>(" ;
6005
6035
if (ExecutionConfig.GroupSize == CanIgnoreRangeStr3D) {
@@ -6013,8 +6043,9 @@ void KernelCallExpr::printParallelFor(KernelPrinter &Printer, bool IsInSubmit) {
6013
6043
Printer << " , " ;
6014
6044
Printer << ExecutionConfig.LocalSize ;
6015
6045
(Printer << " ), " ).newLine ();
6016
- if (!ExecutionConfig.Properties .empty ()) {
6017
- Printer << ExecutionConfig.Properties << " , " ;
6046
+ if (UseEnqueueFunctions) {
6047
+ Printer.line (ExecutionConfig.Properties + " )," );
6048
+ LaunchConfigBlock.reset ();
6018
6049
}
6019
6050
Printer.line (" [=](" + MapNames::getClNamespace () + " nd_item<3> " ,
6020
6051
getItemName (), " )" , ExecutionConfig.SubGroupSize , " {" );
0 commit comments