diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp index b99ed261ecfa3..aa2528af30312 100644 --- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp @@ -385,6 +385,9 @@ LogicalResult GPUModuleConversion::matchAndRewrite( if (auto attr = moduleOp->getAttrOfType( spirv::getTargetEnvAttrName())) spvModule->setAttr(spirv::getTargetEnvAttrName(), attr); + for (const Attribute &targetAttr : moduleOp.getTargetsAttr()) + if (auto spirvTargetEnvAttr = dyn_cast(targetAttr)) + spvModule->setAttr(spirv::getTargetEnvAttrName(), spirvTargetEnvAttr); rewriter.eraseOp(moduleOp); return success(); diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp index 0b2c06a08db2d..29960791e3a4b 100644 --- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp @@ -48,9 +48,34 @@ struct GPUToSPIRVPass final : impl::ConvertGPUToSPIRVBase { void runOnOperation() override; private: + /// Queries the target environment from 'targets' attribute of the given + /// `moduleOp`. + spirv::TargetEnvAttr lookupTargetEnvInTargets(gpu::GPUModuleOp moduleOp); + + /// Queries the target environment from 'targets' attribute of the given + /// `moduleOp` or returns target environment as returned by + /// `spirv::lookupTargetEnvOrDefault` if not provided by 'targets'. + spirv::TargetEnvAttr lookupTargetEnvOrDefault(gpu::GPUModuleOp moduleOp); bool mapMemorySpace; }; +spirv::TargetEnvAttr +GPUToSPIRVPass::lookupTargetEnvInTargets(gpu::GPUModuleOp moduleOp) { + for (const Attribute &targetAttr : moduleOp.getTargetsAttr()) + if (auto spirvTargetEnvAttr = dyn_cast(targetAttr)) + return spirvTargetEnvAttr; + + return {}; +} + +spirv::TargetEnvAttr +GPUToSPIRVPass::lookupTargetEnvOrDefault(gpu::GPUModuleOp moduleOp) { + if (spirv::TargetEnvAttr targetEnvAttr = lookupTargetEnvInTargets(moduleOp)) + return targetEnvAttr; + + return spirv::lookupTargetEnvOrDefault(moduleOp); +} + void GPUToSPIRVPass::runOnOperation() { MLIRContext *context = &getContext(); ModuleOp module = getOperation(); @@ -58,9 +83,8 @@ void GPUToSPIRVPass::runOnOperation() { SmallVector gpuModules; OpBuilder builder(context); - auto targetEnvSupportsKernelCapability = [](gpu::GPUModuleOp moduleOp) { - Operation *gpuModule = moduleOp.getOperation(); - auto targetAttr = spirv::lookupTargetEnvOrDefault(gpuModule); + auto targetEnvSupportsKernelCapability = [this](gpu::GPUModuleOp moduleOp) { + auto targetAttr = lookupTargetEnvOrDefault(moduleOp); spirv::TargetEnv targetEnv(targetAttr); return targetEnv.allows(spirv::Capability::Kernel); }; @@ -86,7 +110,7 @@ void GPUToSPIRVPass::runOnOperation() { // TargetEnv attributes. for (Operation *gpuModule : gpuModules) { spirv::TargetEnvAttr targetAttr = - spirv::lookupTargetEnvOrDefault(gpuModule); + lookupTargetEnvOrDefault(cast(gpuModule)); // Map MemRef memory space to SPIR-V storage class first if requested. if (mapMemorySpace) { diff --git a/mlir/test/Conversion/GPUToSPIRV/lookup-target-env.mlir b/mlir/test/Conversion/GPUToSPIRV/lookup-target-env.mlir new file mode 100644 index 0000000000000..8efda39d9c37e --- /dev/null +++ b/mlir/test/Conversion/GPUToSPIRV/lookup-target-env.mlir @@ -0,0 +1,17 @@ +// RUN: mlir-opt --convert-gpu-to-spirv %s | FileCheck %s + +module attributes {gpu.container_module} { + // CHECK-LABEL: spirv.module @{{.*}} GLSL450 + gpu.module @kernels [#spirv.target_env<#spirv.vce, #spirv.resource_limits<>>] { + // CHECK: spirv.func @load_kernel + // CHECK-SAME: %[[ARG:.*]]: !spirv.ptr [0])>, StorageBuffer> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}) + gpu.func @load_kernel(%arg0: memref<12x4xf32>) kernel attributes {spirv.entry_point_abi = #spirv.entry_point_abi} { + %c0 = arith.constant 0 : index + // CHECK: %[[PTR:.*]] = spirv.AccessChain %[[ARG]]{{\[}}{{%.*}}, {{%.*}}{{\]}} + // CHECK-NEXT: {{%.*}} = spirv.Load "StorageBuffer" %[[PTR]] : f32 + %0 = memref.load %arg0[%c0, %c0] : memref<12x4xf32> + // CHECK: spirv.Return + gpu.return + } + } +}