@@ -455,6 +455,14 @@ struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
     if (!operand)
       return rewriter.notifyMatchFailure(
           subgroupOp, "warp result is not a xegpu::LoadNd op");
+    // Make sure the load op is the last operation in the warp op body. This
+    // ensures that the load op is not sunk earlier, violating any barrier
+    // synchronizations.
+    auto yield = cast<gpu::YieldOp>(
+        subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
+    Operation *lastNode = yield->getPrevNode();
+    if (!dyn_cast_or_null<xegpu::LoadNdOp>(lastNode))
+      return failure();
 
     auto loadOp = operand->get().getDefiningOp<xegpu::LoadNdOp>();
     xegpu::TensorDescType tensorDescTy = loadOp.getTensorDescType();
@@ -782,6 +790,29 @@ struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
   }
 };
 
+/// Sink a gpu::BarrierOp at the end of enclosing `gpu.warp_execute_on_lane_0`
+/// region. This will simply move the barrier op outside of the warp op.
+struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+                                PatternRewriter &rewriter) const override {
+    auto yield = cast<gpu::YieldOp>(
+        subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
+    Operation *lastNode = yield->getPrevNode();
+    // The last node must be a gpu::BarrierOp.
+    auto barrierOp = dyn_cast_or_null<gpu::BarrierOp>(lastNode);
+    if (!barrierOp)
+      return failure();
+    // Move the barrier op outside of the warp op.
+    rewriter.setInsertionPointAfter(subgroupOp);
+    rewriter.create<gpu::BarrierOp>(
+        barrierOp.getLoc(), barrierOp->getResultTypes(),
+        barrierOp->getOperands(), barrierOp->getAttrs());
+    rewriter.eraseOp(barrierOp);
+    return success();
+  }
+};
+
 } // namespace
 
 namespace {
@@ -796,7 +827,8 @@ void xegpu::populateXeGPUSubgroupDistributePatterns(
     RewritePatternSet &patterns) {
   patterns.add<CreateNdDescDistribution, StoreNdDistribution,
                LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
-               UpdateNdOffsetDistribution>(patterns.getContext());
+               UpdateNdOffsetDistribution, GpuBarrierDistribution>(
+      patterns.getContext());
 }
 
 void XeGPUSubgroupDistributePass::runOnOperation() {
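
A rough sketch of the rewrite performed by the new GpuBarrierDistribution pattern
at the IR level. The op names (gpu.warp_execute_on_lane_0, gpu.barrier, gpu.yield)
are the ones matched by the pattern; the region contents and the warp size of 16
are illustrative assumptions, not taken from the patch.

  // Before: the barrier is the last operation before the terminator of the
  // warp region (illustrative body).
  gpu.warp_execute_on_lane_0(%laneid)[16] {
    // ... subgroup work, e.g. xegpu loads/stores ...
    gpu.barrier
    gpu.yield
  }

  // After: an equivalent barrier is recreated right after the warp op and the
  // original one is erased, so the synchronization point survives distribution.
  gpu.warp_execute_on_lane_0(%laneid)[16] {
    // ... subgroup work ...
    gpu.yield
  }
  gpu.barrier

This is also why LoadNdDistribution now requires the load to be the last operation
before the yield: distributing a load that still sits above a barrier could sink it
past that barrier and violate the synchronization the barrier provides.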