[Dispatch Creation] Prefer fusing broadcast with attention #22005

@IanWood1

Description

In the case below, the broadcasts of the K and V operands should be fused with the attention op rather than with their producer (bit-extend) operations. A sketch of the producer-fused form, for comparison, follows the reproducer.

util.func public @attention_broadcast(
  %arg0 : tensor<32x16x?x128xf16>, 
  %arg1 : tensor<4x?x8x128xf8E4M3FN>, 
  %arg2 : tensor<4x?x8x128xf8E4M3FN>, 
  %arg3 : f16, 
  %arg4 : tensor<32x16x?x?xf16>) -> (tensor<32x16x?x128xf16>) {
  %cst = arith.constant 1 : index
  %dim = tensor.dim %arg1, %cst : tensor<4x?x8x128xf8E4M3FN>
  %empty1 = tensor.empty(%dim) : tensor<4x?x8x128xf16>
  %empty2 = tensor.empty(%dim) : tensor<4x8x16x?x128xf16>
  %empty3 = tensor.empty(%dim) : tensor<4x8x16x128x?xf16>
  %empty4 = tensor.empty(%dim) : tensor<32x16x?x128xf16>
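  // Bit-extend the K values from f8E4M3FN to f16 (elementwise, identity maps).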
  %k = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
    ins(%arg1 : tensor<4x?x8x128xf8E4M3FN>)
    outs(%empty1 : tensor<4x?x8x128xf16>){
  ^bb0(%in: f8E4M3FN, %out: f16):
    %extf = arith.extf %in : f8E4M3FN to f16
    linalg.yield %extf : f16
  } -> tensor<4x?x8x128xf16>
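  // Broadcast K along the new size-16 dimension (d2), reordering to 4x8x16x?x128.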
  %k_bcast = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d1, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
    ins(%k : tensor<4x?x8x128xf16>)
    outs(%empty2 : tensor<4x8x16x?x128xf16>){
  ^bb0(%in: f16, %out: f16):
    linalg.yield %in : f16
  } -> tensor<4x8x16x?x128xf16>
  %k_collapse = tensor.collapse_shape %k_bcast [[0, 1], [2], [3], [4]] : tensor<4x8x16x?x128xf16> into tensor<32x16x?x128xf16>
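  // Bit-extend the V values from f8E4M3FN to f16 (elementwise, identity maps).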
  %v = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
    ins(%arg2 : tensor<4x?x8x128xf8E4M3FN>)
    outs(%empty1 : tensor<4x?x8x128xf16>){
  ^bb0(%in: f8E4M3FN, %out: f16):
    %extf = arith.extf %in : f8E4M3FN to f16
    linalg.yield %extf : f16
  } -> tensor<4x?x8x128xf16>
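  // Broadcast V along d2 (16), placing the head dimension (128) before the sequence dimension.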
  %v_bcast = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d1, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
    ins(%v : tensor<4x?x8x128xf16>)
    outs(%empty3 : tensor<4x8x16x128x?xf16>){
  ^bb0(%in: f16, %out: f16):
    linalg.yield %in : f16
  } -> tensor<4x8x16x128x?xf16>
  %v_collapse = tensor.collapse_shape %v_bcast [[0, 1], [2], [3], [4]] : tensor<4x8x16x128x?xf16> into tensor<32x16x128x?xf16>
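  // Attention consumes the collapsed, broadcasted K (32x16x?x128) and transposed V (32x16x128x?).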
  %17 = iree_linalg_ext.attention {
    indexing_maps = [
      affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4)>, 
      affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d4)>, 
      affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d5)>, 
      affine_map<(d0, d1, d2, d3, d4, d5) -> ()>, 
      affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>, 
      affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>]} 
    ins(%arg0, %k_collapse, %v_collapse, %arg3, %arg4 : tensor<32x16x?x128xf16>, tensor<32x16x?x128xf16>, tensor<32x16x128x?xf16>, f16, tensor<32x16x?x?xf16>)
    outs(%empty4 : tensor<32x16x?x128xf16>) {
  ^bb0(%arg8: f32):
      iree_linalg_ext.yield %arg8 : f32
  } -> tensor<32x16x?x128xf16>
  util.return %17 : tensor<32x16x?x128xf16>
}
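
For comparison, here is a rough sketch (mine, not part of the reproducer) of what fusing the broadcast with its producer yields for the K path: the bit-extend and the broadcast collapse into a single elementwise generic that materializes the full 4x8x16x?x128 tensor, 16x larger than %k, before the attention dispatch runs. Fusing the broadcast with the attention op instead would keep the materialized intermediate at the un-broadcast 4x?x8x128 size.

%k_ext_bcast = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d1, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
  ins(%arg1 : tensor<4x?x8x128xf8E4M3FN>)
  outs(%empty2 : tensor<4x8x16x?x128xf16>){
^bb0(%in: f8E4M3FN, %out: f16):
  // Bit-extend and broadcast in one op: each f8 value is read 16 times (once per d2)
  // and the f16 result is written out at the broadcasted shape.
  %extf = arith.extf %in : f8E4M3FN to f16
  linalg.yield %extf : f16
} -> tensor<4x8x16x?x128xf16>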
