Skip to content

Commit a89021b

Browse files
authored
[mlir][spirv] Enable dot operation for bfloat16 (#145409)
Allows dot operations to use vectors of bfloat16 type.
1 parent 13bb328 commit a89021b

File tree

7 files changed

+94
-17
lines changed

7 files changed

+94
-17
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -462,16 +462,19 @@ def SPIRV_DotOp : SPIRV_Op<"Dot",
462462
}];
463463

464464
let arguments = (ins
465-
SPIRV_VectorOf<SPIRV_Float>:$vector1,
466-
SPIRV_VectorOf<SPIRV_Float>:$vector2
465+
SPIRV_VectorOf<SPIRV_AnyFloat>:$vector1,
466+
SPIRV_VectorOf<SPIRV_AnyFloat>:$vector2
467467
);
468468

469469
let results = (outs
470-
SPIRV_Float:$result
470+
SPIRV_AnyFloat:$result
471471
);
472472

473473
let assemblyFormat = "operands attr-dict `:` type($vector1) `->` type($result)";
474474

475+
// Require dynamic availability specification based on operand/result type.
476+
bit autogenAvailability = 0;
477+
475478
let hasVerifier = 0;
476479
}
477480

mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@ add_mlir_dialect_library(MLIRSPIRVDialect
77
CastOps.cpp
88
ControlFlowOps.cpp
99
CooperativeMatrixOps.cpp
10+
DotProductOps.cpp
1011
GroupOps.cpp
1112
ImageOps.cpp
12-
IntegerDotProductOps.cpp
1313
MemoryOps.cpp
1414
MeshOps.cpp
1515
SPIRVAttributes.cpp

mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp renamed to mlir/lib/Dialect/SPIRV/IR/DotProductOps.cpp

Lines changed: 42 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
//===- IntegerDotProductOps.cpp - MLIR SPIR-V Integer Dot Product Ops ----===//
1+
//===- DotProductOps.cpp - MLIR SPIR-V Dot Product Ops -------------------===//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66
//
77
//===----------------------------------------------------------------------===//
88
//
9-
// Defines the Integer Dot Product operations in the SPIR-V dialect.
9+
// Defines the Dot Product operations in the SPIR-V dialect.
1010
//
1111
//===----------------------------------------------------------------------===//
1212

@@ -21,6 +21,44 @@ using namespace mlir::spirv::AttrNames;
2121

2222
namespace mlir::spirv {
2323

24+
//===----------------------------------------------------------------------===//
25+
// Dot Product ops
26+
//===----------------------------------------------------------------------===//
27+
28+
static std::optional<spirv::Version> getDotProductMinVersion() {
29+
return spirv::Version::V_1_0; // Available in SPIR-V >= 1.0.
30+
}
31+
32+
static std::optional<spirv::Version> getDotProductMaxVersion() {
33+
return spirv::Version::V_1_6; // Available in SPIR-V <= 1.6.
34+
}
35+
36+
SmallVector<ArrayRef<spirv::Extension>, 1> DotOp::getExtensions() {
37+
if (isa<BFloat16Type>(getType())) {
38+
static const auto extension = spirv::Extension::SPV_KHR_bfloat16;
39+
return {extension};
40+
}
41+
42+
return {};
43+
}
44+
45+
SmallVector<ArrayRef<spirv::Capability>, 1> DotOp::getCapabilities() {
46+
if (isa<BFloat16Type>(getType())) {
47+
static const auto capability = spirv::Capability::BFloat16DotProductKHR;
48+
return {capability};
49+
}
50+
51+
return {};
52+
}
53+
54+
std::optional<spirv::Version> DotOp::getMinVersion() {
55+
return getDotProductMinVersion();
56+
}
57+
58+
std::optional<spirv::Version> DotOp::getMaxVersion() {
59+
return getDotProductMaxVersion();
60+
}
61+
2462
//===----------------------------------------------------------------------===//
2563
// Integer Dot Product ops
2664
//===----------------------------------------------------------------------===//
@@ -71,14 +109,6 @@ static LogicalResult verifyIntegerDotProduct(Operation *op) {
71109
return success();
72110
}
73111

74-
static std::optional<spirv::Version> getIntegerDotProductMinVersion() {
75-
return spirv::Version::V_1_0; // Available in SPIR-V >= 1.0.
76-
}
77-
78-
static std::optional<spirv::Version> getIntegerDotProductMaxVersion() {
79-
return spirv::Version::V_1_6; // Available in SPIR-V <= 1.6.
80-
}
81-
82112
static SmallVector<ArrayRef<spirv::Extension>, 1>
83113
getIntegerDotProductExtensions() {
84114
// Requires the SPV_KHR_integer_dot_product extension, specified either
@@ -136,10 +166,10 @@ getIntegerDotProductCapabilities(Operation *op) {
136166
return getIntegerDotProductCapabilities<OpName>(*this); \
137167
} \
138168
std::optional<spirv::Version> OpName::getMinVersion() { \
139-
return getIntegerDotProductMinVersion(); \
169+
return getDotProductMinVersion(); \
140170
} \
141171
std::optional<spirv::Version> OpName::getMaxVersion() { \
142-
return getIntegerDotProductMaxVersion(); \
172+
return getDotProductMaxVersion(); \
143173
}
144174

145175
SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(SDotOp)

mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -967,6 +967,22 @@ func.func @reduction_minui(%v : vector<3xi32>, %s: i32) -> i32 {
967967

968968
// -----
969969

970+
module attributes { spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [BFloat16DotProductKHR], [SPV_KHR_bfloat16]>, #spirv.resource_limits<>> } {
971+
972+
// CHECK-LABEL: func @reduction_bf16_addf_mulf
973+
// CHECK-SAME: (%[[ARG0:.+]]: vector<4xbf16>, %[[ARG1:.+]]: vector<4xbf16>)
974+
// CHECK: %[[DOT:.+]] = spirv.Dot %[[ARG0]], %[[ARG1]] : vector<4xbf16> -> bf16
975+
// CHECK: return %[[DOT]] : bf16
976+
func.func @reduction_bf16_addf_mulf(%arg0: vector<4xbf16>, %arg1: vector<4xbf16>) -> bf16 {
977+
%mul = arith.mulf %arg0, %arg1 : vector<4xbf16>
978+
%red = vector.reduction <add>, %mul : vector<4xbf16> into bf16
979+
return %red : bf16
980+
}
981+
982+
} // end module
983+
984+
// -----
985+
970986
// CHECK-LABEL: @shape_cast_same_type
971987
// CHECK-SAME: (%[[ARG0:.*]]: vector<2xf32>)
972988
// CHECK: return %[[ARG0]]

mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,15 @@ func.func @dot(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f32 {
321321

322322
// -----
323323

324+
// CHECK-LABEL: @dot_bf16
325+
func.func @dot_bf16(%arg0: vector<4xbf16>, %arg1: vector<4xbf16>) -> bf16 {
326+
// CHECK: spirv.Dot %{{.+}}, %{{.+}} : vector<4xbf16> -> bf16
327+
%0 = spirv.Dot %arg0, %arg1 : vector<4xbf16> -> bf16
328+
return %0 : bf16
329+
}
330+
331+
// -----
332+
324333
// expected-note @+1 {{prior use here}}
325334
func.func @dot(%arg0: vector<4xf32>, %arg1: vector<3xf32>) -> f32 {
326335
// expected-error @+1 {{use of value '%arg1' expects different type than prior uses}}
@@ -339,7 +348,7 @@ func.func @dot(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f16 {
339348
// -----
340349

341350
func.func @dot(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> i32 {
342-
// expected-error @+1 {{'spirv.Dot' op operand #0 must be vector of 16/32/64-bit float values of length 2/3/4/8/16}}
351+
// expected-error @+1 {{'spirv.Dot' op operand #0 must be vector of 16/32/64-bit float or BFloat16 values of length 2/3/4/8/16}}
343352
%0 = spirv.Dot %arg0, %arg1 : vector<4xi32> -> i32
344353
return %0 : i32
345354
}

mlir/test/Dialect/SPIRV/IR/availability.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,20 @@ func.func @udot_acc_sat_vector_4xi16_i64(%a: vector<4xi16>, %acc: i64) -> i64 {
234234
return %r: i64
235235
}
236236

237+
//===----------------------------------------------------------------------===//
238+
// Dot Product op with bfloat16
239+
//===----------------------------------------------------------------------===//
240+
241+
// CHECK-LABEL: dot_vector_4xbf16_bf16
242+
func.func @dot_vector_4xbf16_bf16(%a: vector<4xbf16>, %b: vector<4xbf16>) -> bf16 {
243+
// CHECK: min version: v1.0
244+
// CHECK: max version: v1.6
245+
// CHECK: extensions: [ [SPV_KHR_bfloat16] ]
246+
// CHECK: capabilities: [ [BFloat16DotProductKHR] ]
247+
%r = spirv.Dot %a, %a: vector<4xbf16> -> bf16
248+
return %r: bf16
249+
}
250+
237251
//===----------------------------------------------------------------------===//
238252
// Primitive ops
239253
//===----------------------------------------------------------------------===//

mlir/test/Target/SPIRV/arithmetic-ops.mlir

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,4 +86,9 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
8686
%0 = spirv.VectorTimesScalar %arg0, %arg1 : (vector<4xf32>, f32) -> vector<4xf32>
8787
spirv.Return
8888
}
89+
spirv.func @dot_bf16(%arg0: vector<4xbf16>, %arg1: vector<4xbf16>) "None" {
90+
// CHECK: spirv.Dot %{{.+}}, %{{.+}} : vector<4xbf16> -> bf16
91+
%0 = spirv.Dot %arg0, %arg1 : vector<4xbf16> -> bf16
92+
spirv.Return
93+
}
8994
}

0 commit comments

Comments
 (0)