In the case below, the broadcast should be fused with the attention op rather than with its producer, the bit-extend (arith.extf) operation. A rough sketch of what such a fused form could look like is included after the reproducer.
util.func public @attention_broadcast(
%arg0 : tensor<32x16x?x128xf16>,
%arg1 : tensor<4x?x8x128xf8E4M3FN>,
%arg2 : tensor<4x?x8x128xf8E4M3FN>,
%arg3 : f16,
%arg4 : tensor<32x16x?x?xf16>) -> (tensor<32x16x?x128xf16>) {
%cst = arith.constant 1 : index
%dim = tensor.dim %arg1, %cst : tensor<4x?x8x128xf8E4M3FN>
%empty1 = tensor.empty(%dim) : tensor<4x?x8x128xf16>
%empty2 = tensor.empty(%dim) : tensor<4x8x16x?x128xf16>
%empty3 = tensor.empty(%dim) : tensor<4x8x16x128x?xf16>
%empty4 = tensor.empty(%dim) : tensor<32x16x?x128xf16>
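// K path: extend f8E4M3FN -> f16 (the producer the broadcast currently gets fused with).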
%k = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
ins(%arg1 : tensor<4x?x8x128xf8E4M3FN>)
outs(%empty1 : tensor<4x?x8x128xf16>){
^bb0(%in: f8E4M3FN, %out: f16):
%extf = arith.extf %in : f8E4M3FN to f16
linalg.yield %extf : f16
} -> tensor<4x?x8x128xf16>
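// Broadcast/transpose K to 4x8x16x?x128 (new 16 dim), then collapse the leading 4x8 dims to 32 below.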
%k_bcast = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d1, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
ins(%k : tensor<4x?x8x128xf16>)
outs(%empty2 : tensor<4x8x16x?x128xf16>){
^bb0(%in: f16, %out: f16):
linalg.yield %in : f16
} -> tensor<4x8x16x?x128xf16>
%k_collapse = tensor.collapse_shape %k_bcast [[0, 1], [2], [3], [4]] : tensor<4x8x16x?x128xf16> into tensor<32x16x?x128xf16>
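// V path: extend f8E4M3FN -> f16.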
%v = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
ins(%arg2 : tensor<4x?x8x128xf8E4M3FN>)
outs(%empty1 : tensor<4x?x8x128xf16>){
^bb0(%in: f8E4M3FN, %out: f16):
%extf = arith.extf %in : f8E4M3FN to f16
linalg.yield %extf : f16
} -> tensor<4x?x8x128xf16>
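// Broadcast/transpose V to 4x8x16x128x?, then collapse the leading 4x8 dims to 32 below.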
%v_bcast = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d1, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
ins(%v : tensor<4x?x8x128xf16>)
outs(%empty3 : tensor<4x8x16x128x?xf16>){
^bb0(%in: f16, %out: f16):
linalg.yield %in : f16
} -> tensor<4x8x16x128x?xf16>
%v_collapse = tensor.collapse_shape %v_bcast [[0, 1], [2], [3], [4]] : tensor<4x8x16x128x?xf16> into tensor<32x16x128x?xf16>
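// Attention: Q = %arg0, K/V = the broadcast results, scale = %arg3, mask = %arg4.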
%17 = iree_linalg_ext.attention {
indexing_maps = [
affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4)>,
affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d4)>,
affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d5)>,
affine_map<(d0, d1, d2, d3, d4, d5) -> ()>,
affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>,
affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>]}
ins(%arg0, %k_collapse, %v_collapse, %arg3, %arg4 : tensor<32x16x?x128xf16>, tensor<32x16x?x128xf16>, tensor<32x16x128x?xf16>, f16, tensor<32x16x?x?xf16>)
outs(%empty4 : tensor<32x16x?x128xf16>) {
^bb0(%arg8: f32):
iree_linalg_ext.yield %arg8 : f32
} -> tensor<32x16x?x128xf16>
util.return %17 : tensor<32x16x?x128xf16>
}
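For illustration, here is a rough sketch (an assumption, not compiler output) of one fused form in which the broadcast and transpose are absorbed into the attention op's indexing maps: the batch dims stay expanded as 4x8x16, K and V are consumed directly as the 4x?x8x128 extf results, and the K/V maps simply do not use the broadcast dimension d2. The names %q_expand, %mask_expand, and %out_empty are hypothetical placeholders for expand_shape'd versions of %arg0 and %arg4 and a matching tensor.empty; the reshapes back to 32x16x?x128 are elided, and this assumes the attention op verifier accepts K/V maps that drop d2.
%fused = iree_linalg_ext.attention {
indexing_maps = [
affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d5)>, // Q: 4x8x16x?x128
affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d6, d1, d5)>, // K: 4x?x8x128 (does not use d2)
affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d6, d1, d4)>, // V: 4x?x8x128 (does not use d2)
affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()>, // scale
affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d6)>, // mask: 4x8x16x?x?
affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)>]} // out: 4x8x16x?x128
ins(%q_expand, %k, %v, %arg3, %mask_expand : tensor<4x8x16x?x128xf16>, tensor<4x?x8x128xf16>, tensor<4x?x8x128xf16>, f16, tensor<4x8x16x?x?xf16>)
outs(%out_empty : tensor<4x8x16x?x128xf16>) {
^bb0(%score: f32):
iree_linalg_ext.yield %score : f32
} -> tensor<4x8x16x?x128xf16>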