Problem Description
I have a native rewrite function, registered with PDL, that rewrites the Values in a ValueRange element by element. It seems to fail to populate the PDL bytecode result list with the new values. My rewrite function looks like this:
```cpp
// Creates an arith.index_cast to i32 for every index-typed value in `vals`;
// values of any other type are skipped.
static ValueRange getI32TensorSizes(PatternRewriter &rewriter,
                                    ValueRange vals) {
  SmallVector<Value> flatI32TensorSizes;
  for (auto val : vals) {
    if (isa<IndexType>(val.getType())) {
      flatI32TensorSizes.push_back(rewriter.create<arith::IndexCastOp>(
          val.getLoc(), rewriter.getIntegerType(32), val));
    }
  }
  return ValueRange(flatI32TensorSizes);
}

patterns.getPDLPatterns().registerRewriteFunction("convert_index_to_i32",
                                                  getI32TensorSizes);
```
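One plausible explanation, offered as a hypothesis rather than a confirmed diagnosis: `ValueRange` is a non-owning view, so `ValueRange(flatI32TensorSizes)` points into a `SmallVector` that is destroyed when `getI32TensorSizes` returns, leaving the caller with a dangling range. The standalone sketch below models that lifetime rule in plain C++ without MLIR. `ValueHandle`, `ValueView`, `castIndexValues`, and `viewOf` are hypothetical names, and an `int` stands in for an MLIR `Value`. It shows the safe shape: return an owning container, and only take a view while its owner is alive.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for an MLIR Value handle (not the real type).
using ValueHandle = int;

// Hypothetical non-owning view, modelled on how mlir::ValueRange and
// llvm::ArrayRef refer to storage they do not own: pointer + length only.
struct ValueView {
  const ValueHandle *data;
  std::size_t size;
};

// Owning version of the rewrite: the returned vector carries its own storage,
// so it stays valid after this function returns. Returning a ValueView over a
// local vector instead would hand the caller a dangling pointer.
std::vector<ValueHandle> castIndexValues(const std::vector<ValueHandle> &vals) {
  std::vector<ValueHandle> casted;
  casted.reserve(vals.size());
  for (ValueHandle v : vals)
    casted.push_back(v + 1000);  // stand-in for creating an arith.index_cast
  assert(casted.size() == vals.size());
  return casted;
}

// A view is only safe while the container it points into is still alive.
ValueView viewOf(const std::vector<ValueHandle> &owner) {
  return ValueView{owner.data(), owner.size()};
}
```

On the MLIR side, the analogous change would presumably be to return `SmallVector<Value>` instead of `ValueRange` from `getI32TensorSizes`. My understanding is that `registerRewriteFunction` accepts `SmallVector<Value, N>` results and copies them into the PDL result list, but that should be checked against the `ProcessPDLValue` specializations in the MLIR checkout being used.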
and my PDL IR looks like this:
```mlir
%workload = pdl.apply_native_rewrite "get_tensor_sizes"(%range : !pdl.range<value>) : !pdl.range<value>
%new_dims = pdl.apply_native_rewrite "convert_index_to_i32"(%workload : !pdl.range<value>) : !pdl.range<value>
```
If I run my interpreter pass with `--debug-only=pdl-bytecode`, it successfully prints the arguments to the native rewrite, but crashes while printing the results:
```
loc("/home/quinn/SHARK-Runtime/samples/custom_dispatch/vulkan/shaders/pattern_module.mlir":62:19)
Executing ApplyRewrite:
  * Arguments: %dim = tensor.dim %1, %c0_0 : tensor<?xf32>
  * Result: Please report issues to https://github.com/openxla/iree/issues and include the crash backtrace.
```
The full stack backtrace is in this gist: https://gist.github.com/qedawkins/2f01e231caa8933c8c75c0b1a83b4d65. It crashes on this debug line: llvm-project/mlir/lib/Rewrite/ByteCode.cpp, line 1450 (commit a9136f0).
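A crash while printing the result, rather than inside the rewrite itself, fits a dangling-range scenario. As far as I can tell, the typed wrapper around a registered rewrite function calls the function first and only afterwards copies the returned `ValueRange` into storage owned by the PDL result list, so that copy (and the debug print) reads whatever the range points at after the function has returned. The sketch below is a hypothetical plain-C++ model of such an owning result list; `ResultListModel`, `ValueView`, and `ValueHandle` are illustrative names, not MLIR APIs. Copying makes the list independent of the caller's buffer, but only if the buffer is still alive at the moment of the copy.

```cpp
#include <cassert>
#include <cstddef>
#include <deque>
#include <vector>

using ValueHandle = int;  // hypothetical stand-in for mlir::Value

struct ValueView {  // hypothetical non-owning view (pointer + length)
  const ValueHandle *data;
  std::size_t size;
};

// Hypothetical model of a result list: each pushed range is copied into
// storage owned by the list. The copy itself still reads through `range`,
// so `range` must point at live memory when push_back runs.
class ResultListModel {
public:
  void push_back(ValueView range) {
    owned_.emplace_back(range.data, range.data + range.size);
  }
  const std::vector<ValueHandle> &operator[](std::size_t i) const {
    assert(i < owned_.size());
    return owned_[i];
  }
  std::size_t size() const { return owned_.size(); }

private:
  // deque keeps references to earlier entries valid as new ones are added.
  std::deque<std::vector<ValueHandle>> owned_;
};
```

Once `push_back` has copied a live range, the original buffer can be freed without affecting the stored result; the failure mode in the issue would be the reverse order, where the buffer is already gone before the copy happens.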
If I print the surrounding IR immediately before returning from the native rewrite function, I can see that the `index_cast` I wanted to insert is there:
```mlir
func.func @mixed_invocation(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {
  %c0 = arith.constant 0 : index
  %0 = hal.buffer_view.dim<%arg0 : !hal.buffer_view>[0] : index
  %1 = hal.tensor.import %arg0 "input 0" : !hal.buffer_view -> tensor<?xf32>{%0}
  %2 = hal.buffer_view.dim<%arg1 : !hal.buffer_view>[0] : index
  %3 = hal.tensor.import %arg1 "input 1" : !hal.buffer_view -> tensor<?xf32>{%2}
  %c0_0 = arith.constant 0 : index
  %dim = tensor.dim %1, %c0_0 : tensor<?xf32>
  %4 = arith.index_cast %dim : index to i32
  %5 = arith.mulf %1, %3 : tensor<?xf32>
  %6 = arith.addf %5, %3 : tensor<?xf32>
  %dim_1 = tensor.dim %6, %c0 : tensor<?xf32>
  %7 = hal.tensor.export %6 "output 0" : tensor<?xf32>{%dim_1} -> !hal.buffer_view
  return %7 : !hal.buffer_view
}
```
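Seeing `%4 = arith.index_cast` in the printed IR does not by itself show that the returned range is valid. The new operation is owned by the enclosing block, so it survives the call, while the list of handles to its results lives in `flatI32TensorSizes`, a local `SmallVector`. The hypothetical plain-C++ model below separates those two lifetimes; `Op`, `Block`, and `createCasts` are illustrative names, not MLIR types. Ops are owned by a long-lived container, and the function returns its handles in an owning vector so both remain usable by the caller.

```cpp
#include <cassert>
#include <deque>
#include <string>
#include <vector>

// Hypothetical model: a "block" owns its operations (as an MLIR block owns
// its ops), and callers refer to them through plain pointers (handles).
struct Op {
  std::string name;
};

struct Block {
  std::deque<Op> ops;  // deque: pointers to existing ops stay valid on push_back
  Op *create(const std::string &name) {
    ops.push_back(Op{name});
    return &ops.back();
  }
};

// Creates one op per requested cast and returns the handles by value.
// The ops outlive this call because `block` owns them; the handle list
// outlives it because it is returned as an owning vector, not a view.
std::vector<Op *> createCasts(Block &block, int count) {
  std::vector<Op *> handles;
  for (int i = 0; i < count; ++i)
    handles.push_back(block.create("arith.index_cast"));
  assert(static_cast<int>(handles.size()) == count);
  return handles;
}
```

If `createCasts` returned a non-owning view over `handles` instead, the ops would still appear when the block is printed, yet the caller's view would dangle, which matches the symptom described above.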
%4 = arith.index_cast %dim : index to i32