@@ -5907,7 +5907,6 @@ void KernelCallExpr::printSubmit(KernelPrinter &Printer) {
5907
5907
Printer << " *" << getEvent () << " = " ;
5908
5908
}
5909
5909
5910
- printStreamBase (Printer);
5911
5910
if (isDefaultStream ()) {
5912
5911
SubmitStmts.DefaultStreamFlag = true ;
5913
5912
}
@@ -5916,8 +5915,12 @@ void KernelCallExpr::printSubmit(KernelPrinter &Printer) {
5916
5915
SubmitStmts.ImplicitSyncFlag = true ;
5917
5916
}
5918
5917
if (SubmitStmts.empty ()) {
5918
+ if (ExecutionConfig.Properties .empty ()) {
5919
+ printStreamBase (Printer);
5920
+ }
5919
5921
printParallelFor (Printer, false );
5920
5922
} else {
5923
+ printStreamBase (Printer);
5921
5924
(Printer << " submit(" ).newLine ();
5922
5925
printSubmitLambda (Printer);
5923
5926
}
@@ -5949,12 +5952,20 @@ void KernelCallExpr::printParallelFor(KernelPrinter &Printer, bool IsInSubmit) {
5949
5952
}
5950
5953
}
5951
5954
}
5955
+ bool UseEnqueueFunctions = !ExecutionConfig.Properties .empty ();
5952
5956
if (IsInSubmit) {
5953
- Printer.indent () << " cgh." ;
5957
+ Printer.indent ();
5958
+ if (!UseEnqueueFunctions) {
5959
+ Printer << " cgh." ;
5960
+ }
5954
5961
}
5955
5962
if (!SubmitStmts.NdRangeList .empty () && DpctGlobalInfo::isCommentsEnabled ())
5956
5963
Printer.line (" // run the kernel within defined ND range" );
5957
- Printer << " parallel_for" ;
5964
+ if (UseEnqueueFunctions) {
5965
+ Printer << MapNames::getExpNamespace ();
5966
+ }
5967
+ // Printer << "parallel_for";
5968
+ Printer << " nd_launch" ;
5958
5969
if (DpctGlobalInfo::isSyclNamedLambda ()) {
5959
5970
Printer << " <dpct_kernel_name<class " << getName () << " _"
5960
5971
<< LocInfo.LocHash ;
@@ -5965,16 +5976,26 @@ void KernelCallExpr::printParallelFor(KernelPrinter &Printer, bool IsInSubmit) {
5965
5976
}
5966
5977
(Printer << " (" ).newLine ();
5967
5978
auto B = Printer.block ();
5979
+ std::unique_ptr<KernelPrinter::Block> LaunchConfigBlock;
5980
+ if (UseEnqueueFunctions) {
5981
+ (Printer.indent () << (IsInSubmit ? " cgh" : ExecutionConfig.Stream ) << " ," )
5982
+ .newLine ();
5983
+ }
5968
5984
static std::string CanIgnoreRangeStr3D =
5969
5985
DpctGlobalInfo::getCtadClass (MapNames::getClNamespace () + " range" , 3 ) +
5970
5986
" (1, 1, 1)" ;
5971
5987
static std::string CanIgnoreRangeStr1D =
5972
5988
DpctGlobalInfo::getCtadClass (MapNames::getClNamespace () + " range" , 1 ) +
5973
5989
" (1)" ;
5974
5990
if (ExecutionConfig.NdRange != " " ) {
5991
+ if (UseEnqueueFunctions) {
5992
+ Printer.line (MapNames::getExpNamespace () + " launch_config(" );
5993
+ LaunchConfigBlock = std::move (Printer.block ());
5994
+ }
5975
5995
Printer.line (ExecutionConfig.NdRange + " ," );
5976
- if (!ExecutionConfig.Properties .empty ()) {
5977
- Printer << ExecutionConfig.Properties << " , " ;
5996
+ if (UseEnqueueFunctions) {
5997
+ Printer.line (ExecutionConfig.Properties + " )," );
5998
+ LaunchConfigBlock.reset ();
5978
5999
}
5979
6000
Printer.line (" [=](" , MapNames::getClNamespace (), " nd_item<3> " ,
5980
6001
getItemName (), " )" , ExecutionConfig.SubGroupSize , " {" );
@@ -5984,6 +6005,10 @@ void KernelCallExpr::printParallelFor(KernelPrinter &Printer, bool IsInSubmit) {
5984
6005
MemVarMap::getHeadWithoutPathCompression (
5985
6006
&(getFuncInfo ()->getVarMap ()))
5986
6007
->Dim == 1 ) {
6008
+ if (UseEnqueueFunctions) {
6009
+ Printer.line (MapNames::getExpNamespace () + " launch_config(" );
6010
+ LaunchConfigBlock = std::move (Printer.block ());
6011
+ }
5987
6012
DpctGlobalInfo::printCtadClass (Printer.indent (),
5988
6013
MapNames::getClNamespace () + " nd_range" , 1 )
5989
6014
<< " (" ;
@@ -5998,12 +6023,17 @@ void KernelCallExpr::printParallelFor(KernelPrinter &Printer, bool IsInSubmit) {
5998
6023
Printer << " , " ;
5999
6024
Printer << ExecutionConfig.LocalSizeFor1D ;
6000
6025
(Printer << " ), " ).newLine ();
6001
- if (!ExecutionConfig.Properties .empty ()) {
6002
- Printer << ExecutionConfig.Properties << " , " ;
6026
+ if (UseEnqueueFunctions) {
6027
+ Printer.line (ExecutionConfig.Properties + " )," );
6028
+ LaunchConfigBlock.reset ();
6003
6029
}
6004
6030
Printer.line (" [=](" + MapNames::getClNamespace () + " nd_item<1> " ,
6005
6031
getItemName (), " )" , ExecutionConfig.SubGroupSize , " {" );
6006
6032
} else {
6033
+ if (UseEnqueueFunctions) {
6034
+ Printer.line (MapNames::getExpNamespace () + " launch_config(" );
6035
+ LaunchConfigBlock = std::move (Printer.block ());
6036
+ }
6007
6037
Printer.indent ();
6008
6038
Printer << MapNames::getClNamespace () + " nd_range<3>(" ;
6009
6039
if (ExecutionConfig.GroupSize == CanIgnoreRangeStr3D) {
@@ -6017,8 +6047,9 @@ void KernelCallExpr::printParallelFor(KernelPrinter &Printer, bool IsInSubmit) {
6017
6047
Printer << " , " ;
6018
6048
Printer << ExecutionConfig.LocalSize ;
6019
6049
(Printer << " ), " ).newLine ();
6020
- if (!ExecutionConfig.Properties .empty ()) {
6021
- Printer << ExecutionConfig.Properties << " , " ;
6050
+ if (UseEnqueueFunctions) {
6051
+ Printer.line (ExecutionConfig.Properties + " )," );
6052
+ LaunchConfigBlock.reset ();
6022
6053
}
6023
6054
Printer.line (" [=](" + MapNames::getClNamespace () + " nd_item<3> " ,
6024
6055
getItemName (), " )" , ExecutionConfig.SubGroupSize , " {" );
0 commit comments