Skip to content

[SPIR-V] Fix r-value being used in mul intrinsic #7489

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions tools/clang/lib/SPIRV/SpirvEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1260,6 +1260,15 @@ SpirvInstruction *SpirvEmitter::doExpr(const Expr *expr,
return result;
}

SpirvInstruction *SpirvEmitter::doExprEnsuringRValue(const Expr *E,
SourceLocation location,
SourceRange range) {
SpirvInstruction *I = doExpr(E);
if (I->isRValue())
return I;
return spvBuilder.createLoad(E->getType(), I, location, range);
}

SpirvInstruction *SpirvEmitter::loadIfGLValue(const Expr *expr,
SourceRange rangeOverride) {
// We are trying to load the value here, which is what an LValueToRValue
Expand Down Expand Up @@ -11364,8 +11373,8 @@ SpirvInstruction *SpirvEmitter::processIntrinsicMul(const CallExpr *callExpr) {
uint32_t numRows = 0;
if (isMxNMatrix(returnType, &elemType, &numRows)) {
llvm::SmallVector<SpirvInstruction *, 4> rows;
auto *arg0Id = doExpr(arg0);
auto *arg1Id = doExpr(arg1);
auto *arg0Id = doExprEnsuringRValue(arg0, loc, range);
auto *arg1Id = doExprEnsuringRValue(arg1, loc, range);
for (uint32_t i = 0; i < numRows; ++i) {
auto *scalar = spvBuilder.createCompositeExtract(elemType, arg0Id, {i},
loc, range);
Expand All @@ -11380,8 +11389,8 @@ SpirvInstruction *SpirvEmitter::processIntrinsicMul(const CallExpr *callExpr) {
}

// All the following cases require handling arg0 and arg1 expressions first.
auto *arg0Id = doExpr(arg0);
auto *arg1Id = doExpr(arg1);
auto *arg0Id = doExprEnsuringRValue(arg0, loc, range);
auto *arg1Id = doExprEnsuringRValue(arg1, loc, range);

// mul(scalar, scalar)
if (isScalarType(arg0Type) && isScalarType(arg1Type))
Expand Down
3 changes: 3 additions & 0 deletions tools/clang/lib/SPIRV/SpirvEmitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ class SpirvEmitter : public ASTConsumer {
void doDecl(const Decl *decl);
void doStmt(const Stmt *stmt, llvm::ArrayRef<const Attr *> attrs = {});
SpirvInstruction *doExpr(const Expr *expr, SourceRange rangeOverride = {});
SpirvInstruction *doExprEnsuringRValue(const Expr *expr,
SourceLocation location,
SourceRange range);

/// Processes the given expression and emits SPIR-V instructions. If the
/// result is a GLValue, does an additional load.
Expand Down
25 changes: 25 additions & 0 deletions tools/clang/test/CodeGenSPIRV/intrinsics.mul.hlsl
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
// RUN: %dxc -T ps_6_0 -E main -fcgl %s -spirv | FileCheck %s

StructuredBuffer<float3> buffer_vec;
StructuredBuffer<float3x3> buffer_mat;

/*
According to HLSL reference, mul() has the following versions:

Expand Down Expand Up @@ -448,6 +451,7 @@ void main() {
// mul( Mat(Mx1) * Mat(1xN) ) --> Mat(MxN) matrix
float1x3 mat1x3;
float3x2 mat3x2;
float3x3 mat3x3;
float3x1 mat3x1;
float1x4 mat1x4;

Expand All @@ -474,4 +478,25 @@ void main() {
// CHECK-NEXT: [[result3:%[0-9]+]] = OpCompositeConstruct %mat3v4float [[row0]] [[row1]] [[row2]]
// CHECK-NEXT: OpStore %result3 [[result3]]
float3x4 result3 = mul( mat3x1, mat1x4 ); // result is float3x4 matrix

float3 v3;

// CHECK: [[matp:%[0-9]+]] = OpAccessChain %_ptr_Uniform_mat3v3float %buffer_mat %int_0 %int_0
// CHECK: [[mat:%[0-9]+]] = OpLoad %mat3v3float [[matp]]
// CHECK: [[vec:%[0-9]+]] = OpLoad %v3float %v3
// CHECK: {{.*}} = OpVectorTimesMatrix %v3float [[vec]] [[mat]]
float3 result4 = mul(buffer_mat.Load(0), v3);

// CHECK: [[mat:%[0-9]+]] = OpLoad %mat3v3float %mat3x3
// CHECK: [[vecp:%[0-9]+]] = OpAccessChain %_ptr_Uniform_v3float %buffer_vec %int_0 %int_1
// CHECK: [[vec:%[0-9]+]] = OpLoad %v3float [[vecp]]
// CHECK: {{.*}} = OpVectorTimesMatrix %v3float [[vec]] [[mat]]
float3 result5 = mul(mat3x3, buffer_vec.Load(1));

// CHECK: [[matp:%[0-9]+]] = OpAccessChain %_ptr_Uniform_mat3v3float %buffer_mat %int_0 %int_2
// CHECK: [[mat:%[0-9]+]] = OpLoad %mat3v3float [[matp]]
// CHECK: [[vecp:%[0-9]+]] = OpAccessChain %_ptr_Uniform_v3float %buffer_vec %int_0 %int_2
// CHECK: [[vec:%[0-9]+]] = OpLoad %v3float [[vecp]]
// CHECK: {{.*}} = OpVectorTimesMatrix %v3float [[vec]] [[mat]]
float3 result6 = mul(buffer_mat.Load(2), buffer_vec.Load(2));
}