
[MLIR][PDL] Invalid value from native rewrite #69888

Open
@qedawkins


Problem Description

I have a native rewrite function registered with PDL that rewrites the Values in a ValueRange element by element, but it seemingly fails to populate the PDL bytecode result list with the new values. My rewrite function looks like this:

static ValueRange getI32TensorSizes(PatternRewriter &rewriter,
                                    ValueRange vals) {
  SmallVector<Value> flatI32TensorSizes;
  for (auto val : vals) {
    if (isa<IndexType>(val.getType())) {
      flatI32TensorSizes.push_back(rewriter.create<arith::IndexCastOp>(
          val.getLoc(), rewriter.getIntegerType(32), val));
    }
  }
  return ValueRange(flatI32TensorSizes);
}

// Registration:
patterns.getPDLPatterns().registerRewriteFunction("convert_index_to_i32",
                                                  getI32TensorSizes);
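A likely cause, offered as a hypothesis rather than a confirmed diagnosis: `mlir::ValueRange` is a non-owning view (it does not own the storage it refers to), so `return ValueRange(flatI32TensorSizes);` returns a view into a `SmallVector` that is destroyed when `getI32TensorSizes` returns. The standalone C++ sketch below reproduces that lifetime pattern without MLIR; `IntView`, `convertBroken`, `convertOwning`, and `sumView` are hypothetical names, with `IntView` standing in for `ValueRange` and `std::vector` standing in for `SmallVector`.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a non-owning range such as mlir::ValueRange:
// it records where the elements live, but never owns them.
struct IntView {
  const int *data;
  std::size_t size;
};

// BROKEN pattern, mirroring getI32TensorSizes: the view points into `local`,
// whose storage is released when the function returns. Reading through the
// returned view afterwards is undefined behavior, so this sketch never does.
IntView convertBroken(const std::vector<long> &vals) {
  std::vector<int> local;
  for (long v : vals)
    local.push_back(static_cast<int>(v));
  return IntView{local.data(), local.size()};
}

// Safer pattern: return the owning container by value; callers build a view
// only while that container is still alive.
std::vector<int> convertOwning(const std::vector<long> &vals) {
  std::vector<int> out;
  out.reserve(vals.size());
  for (long v : vals)
    out.push_back(static_cast<int>(v));
  return out;
}

// Reads every element through a view; only safe if the storage is alive.
int sumView(IntView view) {
  assert(view.data != nullptr || view.size == 0);
  int total = 0;
  for (std::size_t i = 0; i < view.size; ++i)
    total += view.data[i];
  return total;
}
```

For example, `std::vector<int> owned = convertOwning({1, 2, 3});` followed by `sumView(IntView{owned.data(), owned.size()})` is well defined because `owned` outlives the view. If the PDL rewrite-function wrapper in the MLIR version in use accepts an owning `SmallVector<Value>` return type (worth checking against the `ProcessPDLValue` specializations in the MLIR headers), returning `flatI32TensorSizes` by value instead of wrapping it in a `ValueRange` would sidestep the dangling view.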

And my PDL IR looks like this:

%workload = pdl.apply_native_rewrite "get_tensor_sizes"(%range : !pdl.range<value>) : !pdl.range<value>
%new_dims = pdl.apply_native_rewrite "convert_index_to_i32"(%workload : !pdl.range<value>) : !pdl.range<value>

If I run my interpreter pass with --debug-only=pdl-bytecode, it successfully prints the arguments to the native rewrite, but crashes when printing the result:

loc("/home/quinn/SHARK-Runtime/samples/custom_dispatch/vulkan/shaders/pattern_module.mlir":62:19)
Executing ApplyRewrite:
  * Arguments: %dim = tensor.dim %1, %c0_0 : tensor<?xf32>
  * Result: Please report issues to https://github.com/openxla/iree/issues and include the crash backtrace.

The full stack backtrace is available at https://gist.github.com/qedawkins/2f01e231caa8933c8c75c0b1a83b4d65. It crashes on this debug line:

LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n");
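The crash while printing the result fits the same picture: whatever stores the returned range in the bytecode's result list necessarily reads it after `getI32TensorSizes` has already returned. The sketch below is a hypothetical, simplified result list (`ResultList` and `IntView` are illustrative names, not MLIR's `PDLResultList` API) that copies each pushed range into storage it owns. Even a copying list like this only helps if the source range is still alive at the moment `push` runs, which is not the case for a view into a destroyed local vector.

```cpp
#include <cassert>
#include <cstddef>
#include <deque>
#include <vector>

// Hypothetical non-owning view, standing in for a range like ValueRange.
struct IntView {
  const int *data;
  std::size_t size;
};

// Hypothetical result list: copies each pushed range into storage it owns,
// so the views it hands out stay valid for the lifetime of the list.
class ResultList {
public:
  void push(IntView range) {
    assert(range.data != nullptr || range.size == 0);
    // Copying reads range.data, so the source storage must still be alive here.
    storage_.emplace_back(range.data, range.data + range.size);
    const std::vector<int> &owned = storage_.back();
    views_.push_back(IntView{owned.data(), owned.size()});
  }

  IntView get(std::size_t i) const { return views_.at(i); }
  std::size_t size() const { return views_.size(); }

private:
  // std::deque keeps references to existing elements valid across push_back,
  // so earlier views into storage_ are never invalidated.
  std::deque<std::vector<int>> storage_;
  std::vector<IntView> views_;
};
```

In this design the copy happens at push time, so a range pushed while its source is alive survives the source being destroyed afterwards; a range whose source is already gone would be copied from freed memory.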

If I print the surrounding IR immediately before returning from the native rewrite function, I can see that the index_cast I wanted to insert is there:

func.func @mixed_invocation(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {
  %c0 = arith.constant 0 : index
  %0 = hal.buffer_view.dim<%arg0 : !hal.buffer_view>[0] : index
  %1 = hal.tensor.import %arg0 "input 0" : !hal.buffer_view -> tensor<?xf32>{%0}
  %2 = hal.buffer_view.dim<%arg1 : !hal.buffer_view>[0] : index
  %3 = hal.tensor.import %arg1 "input 1" : !hal.buffer_view -> tensor<?xf32>{%2}
  %c0_0 = arith.constant 0 : index
  %dim = tensor.dim %1, %c0_0 : tensor<?xf32>
  %4 = arith.index_cast %dim : index to i32
  %5 = arith.mulf %1, %3 : tensor<?xf32>
  %6 = arith.addf %5, %3 : tensor<?xf32>
  %dim_1 = tensor.dim %6, %c0 : tensor<?xf32>
  %7 = hal.tensor.export %6 "output 0" : tensor<?xf32>{%dim_1} -> !hal.buffer_view
  return %7 : !hal.buffer_view
}
%4 = arith.index_cast %dim : index to i32

cc @MaheshRavishankar

Metadata

Assignees: No one assigned

Labels: bug, mlir:pdl