diff --git a/tools/clang/lib/SPIRV/SpirvEmitter.cpp b/tools/clang/lib/SPIRV/SpirvEmitter.cpp index ea2347edce..92e4c687ca 100644 --- a/tools/clang/lib/SPIRV/SpirvEmitter.cpp +++ b/tools/clang/lib/SPIRV/SpirvEmitter.cpp @@ -1260,6 +1260,15 @@ SpirvInstruction *SpirvEmitter::doExpr(const Expr *expr, return result; } +SpirvInstruction *SpirvEmitter::doExprEnsuringRValue(const Expr *E, + SourceLocation location, + SourceRange range) { + SpirvInstruction *I = doExpr(E); + if (I->isRValue()) + return I; + return spvBuilder.createLoad(E->getType(), I, location, range); +} + SpirvInstruction *SpirvEmitter::loadIfGLValue(const Expr *expr, SourceRange rangeOverride) { // We are trying to load the value here, which is what an LValueToRValue @@ -11364,8 +11373,8 @@ SpirvInstruction *SpirvEmitter::processIntrinsicMul(const CallExpr *callExpr) { uint32_t numRows = 0; if (isMxNMatrix(returnType, &elemType, &numRows)) { llvm::SmallVector rows; - auto *arg0Id = doExpr(arg0); - auto *arg1Id = doExpr(arg1); + auto *arg0Id = doExprEnsuringRValue(arg0, loc, range); + auto *arg1Id = doExprEnsuringRValue(arg1, loc, range); for (uint32_t i = 0; i < numRows; ++i) { auto *scalar = spvBuilder.createCompositeExtract(elemType, arg0Id, {i}, loc, range); @@ -11380,8 +11389,8 @@ SpirvInstruction *SpirvEmitter::processIntrinsicMul(const CallExpr *callExpr) { } // All the following cases require handling arg0 and arg1 expressions first. - auto *arg0Id = doExpr(arg0); - auto *arg1Id = doExpr(arg1); + auto *arg0Id = doExprEnsuringRValue(arg0, loc, range); + auto *arg1Id = doExprEnsuringRValue(arg1, loc, range); // mul(scalar, scalar) if (isScalarType(arg0Type) && isScalarType(arg1Type)) diff --git a/tools/clang/lib/SPIRV/SpirvEmitter.h b/tools/clang/lib/SPIRV/SpirvEmitter.h index 978e88e4ed..e5daed603d 100644 --- a/tools/clang/lib/SPIRV/SpirvEmitter.h +++ b/tools/clang/lib/SPIRV/SpirvEmitter.h @@ -80,6 +80,9 @@ class SpirvEmitter : public ASTConsumer { void doDecl(const Decl *decl); void doStmt(const Stmt *stmt, llvm::ArrayRef attrs = {}); SpirvInstruction *doExpr(const Expr *expr, SourceRange rangeOverride = {}); + SpirvInstruction *doExprEnsuringRValue(const Expr *expr, + SourceLocation location, + SourceRange range); /// Processes the given expression and emits SPIR-V instructions. If the /// result is a GLValue, does an additional load. diff --git a/tools/clang/test/CodeGenSPIRV/intrinsics.mul.hlsl b/tools/clang/test/CodeGenSPIRV/intrinsics.mul.hlsl index 4d04896781..629e7527c3 100644 --- a/tools/clang/test/CodeGenSPIRV/intrinsics.mul.hlsl +++ b/tools/clang/test/CodeGenSPIRV/intrinsics.mul.hlsl @@ -1,5 +1,8 @@ // RUN: %dxc -T ps_6_0 -E main -fcgl %s -spirv | FileCheck %s +StructuredBuffer buffer_vec; +StructuredBuffer buffer_mat; + /* According to HLSL reference, mul() has the following versions: @@ -448,6 +451,7 @@ void main() { // mul( Mat(Mx1) * Mat(1xN) ) --> Mat(MxN) matrix float1x3 mat1x3; float3x2 mat3x2; + float3x3 mat3x3; float3x1 mat3x1; float1x4 mat1x4; @@ -474,4 +478,25 @@ void main() { // CHECK-NEXT: [[result3:%[0-9]+]] = OpCompositeConstruct %mat3v4float [[row0]] [[row1]] [[row2]] // CHECK-NEXT: OpStore %result3 [[result3]] float3x4 result3 = mul( mat3x1, mat1x4 ); // result is float3x4 matrix + + float3 v3; + +// CHECK: [[matp:%[0-9]+]] = OpAccessChain %_ptr_Uniform_mat3v3float %buffer_mat %int_0 %int_0 +// CHECK: [[mat:%[0-9]+]] = OpLoad %mat3v3float [[matp]] +// CHECK: [[vec:%[0-9]+]] = OpLoad %v3float %v3 +// CHECK: {{.*}} = OpVectorTimesMatrix %v3float [[vec]] [[mat]] + float3 result4 = mul(buffer_mat.Load(0), v3); + +// CHECK: [[mat:%[0-9]+]] = OpLoad %mat3v3float %mat3x3 +// CHECK: [[vecp:%[0-9]+]] = OpAccessChain %_ptr_Uniform_v3float %buffer_vec %int_0 %int_1 +// CHECK: [[vec:%[0-9]+]] = OpLoad %v3float [[vecp]] +// CHECK: {{.*}} = OpVectorTimesMatrix %v3float [[vec]] [[mat]] + float3 result5 = mul(mat3x3, buffer_vec.Load(1)); + +// CHECK: [[matp:%[0-9]+]] = OpAccessChain %_ptr_Uniform_mat3v3float %buffer_mat %int_0 %int_2 +// CHECK: [[mat:%[0-9]+]] = OpLoad %mat3v3float [[matp]] +// CHECK: [[vecp:%[0-9]+]] = OpAccessChain %_ptr_Uniform_v3float %buffer_vec %int_0 %int_2 +// CHECK: [[vec:%[0-9]+]] = OpLoad %v3float [[vecp]] +// CHECK: {{.*}} = OpVectorTimesMatrix %v3float [[vec]] [[mat]] + float3 result6 = mul(buffer_mat.Load(2), buffer_vec.Load(2)); }