diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h index 306e4b1f218e7..33fde665e8108 100644 --- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h @@ -42,6 +42,8 @@ void populateTosaConstantReduction(MLIRContext *ctx, void populateTosaTypeConversion(TypeConverter &converter); std::unique_ptr createTosaTestQuantUtilAPIPass(); +std::unique_ptr +createTosaInputShapePass(std::vector args = {}); #define GEN_PASS_REGISTRATION #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc" diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td index d005a4cc6859c..dd5c11f6aac7e 100644 --- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td @@ -127,4 +127,25 @@ def TosaReduceTransposes : Pass<"tosa-reduce-transposes", "func::FuncOp"> { }]; } +def TosaInputShape : Pass<"experimental-tosa-input-shape", "func::FuncOp"> { + let summary = "Override dynamic input shapes of function arguments to specified static shapes."; + let description = [{ + Pass that overrides the dynamic input shapes of function arguments to specified static shapes. + It is an error if a specified static shape conflicts with the static dimensions in an original input shape. + }]; + + let constructor = "tosa::createTosaInputShapePass()"; + let dependentDialects = [ + "func::FuncDialect", + "tensor::TensorDialect", + "tosa::TosaDialect", + ]; + let options = [ + ListOption<"args", "args", "std::string", + "Comma-separated list of shape descriptions. Each description contains the " + "argument name, a colon, and a shape with dimensions separated by x " + "(e.g. arg0:5x5,arg3:2x64)">, + ]; +} + #endif // MLIR_DIALECT_TOSA_TRANSFORMS_PASSES diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt index bbf079faea3d0..d9458886c0f95 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt @@ -10,6 +10,7 @@ add_mlir_dialect_library(MLIRTosaTransforms TosaTypeConverters.cpp TosaProfileCompliance.cpp TosaValidation.cpp + TosaInputShape.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa/Transforms diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaInputShape.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaInputShape.cpp new file mode 100644 index 0000000000000..a78534daec31e --- /dev/null +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaInputShape.cpp @@ -0,0 +1,177 @@ +//===- TosaInputShape.cpp -------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Change input shape of function argument to specified shape. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Tosa/Transforms/Passes.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/Matchers.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/Support/FormatVariadic.h" + +namespace mlir { +namespace tosa { +#define GEN_PASS_DEF_TOSAINPUTSHAPE +#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc" +} // namespace tosa +} // namespace mlir + +using namespace mlir; +using namespace mlir::tosa; + +namespace { + +std::pair>>, std::string> +parse_input_shapes(std::vector args) { + /** + * This function returns two values: a vector of parsed arguments, and an + * optional error message. Each arguments contains its argument number and the + * shape. For example: + * "args=arg0:5x10,arg8:3x9" => {{{0, {5, 10}}, {8, {3, 9}}}, ""} + * "args=arg0:" => {{}, "error message"} + */ + + std::vector>> shapes; + + for (std::string arg : args) { + if (arg.substr(0, 3) != "arg") { + return {{}, "Arguments must start with 'arg'"}; + } + + char *endptr; + size_t argnum = std::strtoul(&arg[3], &endptr, /*base=*/10); + if (*endptr != ':') { + return {{}, "Invalid argument name"}; + } + std::string shape_str = endptr + 1; + + std::vector curr; + while (!shape_str.empty()) { + size_t dim = std::strtoul(shape_str.data(), &endptr, /*base=*/10); + if ((*endptr != '\0' && *endptr != 'x') || shape_str == endptr) { + return {{}, "Invalid input shape description"}; + } + curr.push_back(dim); + if (*endptr == '\0') { + break; + } + shape_str = endptr + 1; + } + shapes.push_back({argnum, curr}); + } + return {shapes, ""}; +} + +/// Pass that change function input shapes to specified static input shapes +struct TosaInputShape : public tosa::impl::TosaInputShapeBase { +public: + TosaInputShape() = default; + explicit TosaInputShape(std::vector args) : TosaInputShape() { + this->args = args; + } + void runOnOperation() override { + func::FuncOp func = getOperation(); + auto [args_parsed, args_parse_err] = parse_input_shapes(args); + + if (!args_parse_err.empty()) { + func.emitError() << args_parse_err; + return; + } + + for (auto &block : func.getBody()) { + + for (auto [argnum, shape] : args_parsed) { + if (argnum >= block.getNumArguments()) { + func.emitError() << "arg" << argnum << " doesn't exist."; + return; + } + BlockArgument block_arg = block.getArgument(argnum); + Type arg_type = block_arg.getType(); + TensorType tensor_type = cast(arg_type); + if (failed( + mlir::verifyCompatibleShape(tensor_type.getShape(), shape))) { + func->emitError() + << "arg" << argnum << " has incompatible shape with input shape."; + return; + } + SmallVector new_shape(shape.begin(), shape.end()); + auto new_tensor_type = + tensor_type.cloneWith(new_shape, tensor_type.getElementType()); + block_arg.setType(new_tensor_type); + } + + bool found_func_op = false; + + for (Operation &op : block) { + // Update result shape for func.func + func::FuncOp funcOp = mlir::dyn_cast(op.getParentOp()); + if (funcOp && !found_func_op) { + FunctionType old_function_type = funcOp.getFunctionType(); + std::vector inputs = old_function_type.getInputs(); + + for (auto [argnum, shape] : args_parsed) { + if ((size_t)argnum >= inputs.size()) { + func.emitError() << "arg" << argnum << " doesn't exist."; + return; + } + auto tensor_type = cast(inputs[argnum]); + + if (failed(mlir::verifyCompatibleShape(tensor_type.getShape(), + shape))) { + funcOp->emitError() + << "arg" << argnum + << " has incompatible shape with input shape."; + return; + } + SmallVector new_shape(shape.begin(), shape.end()); + auto new_tensor_type = + tensor_type.cloneWith(new_shape, tensor_type.getElementType()); + inputs[argnum] = cast(new_tensor_type); + } + + FunctionType new_function_type = old_function_type.clone( + TypeRange{ArrayRef(inputs)}, + TypeRange{old_function_type.getResults()}); + funcOp.setFunctionType(new_function_type); + found_func_op = true; + } + // Update result shape of func.return + func::ReturnOp returnOp = mlir::dyn_cast(op); + if (returnOp) { + func::FuncOp funcOp = dyn_cast(op.getParentOp()); + if (funcOp) { + FunctionType old_function_type = funcOp.getFunctionType(); + FunctionType new_function_type = old_function_type.clone( + TypeRange{old_function_type.getInputs()}, + returnOp.getOperandTypes()); + funcOp.setFunctionType(new_function_type); + } + } + } + } + } +}; + +} // namespace + +std::unique_ptr +mlir::tosa::createTosaInputShapePass(std::vector args) { + return std::make_unique(args); +} diff --git a/mlir/test/Dialect/Tosa/tosa-input-shape.mlir b/mlir/test/Dialect/Tosa/tosa-input-shape.mlir new file mode 100644 index 0000000000000..2a784aa3d33cb --- /dev/null +++ b/mlir/test/Dialect/Tosa/tosa-input-shape.mlir @@ -0,0 +1,15 @@ +// RUN: mlir-opt --split-input-file --experimental-tosa-input-shape="args=arg0:2x16,arg3:64x9" %s | FileCheck %s + +func.func @test_input_shape( + // CHECK: %arg0: tensor<2x16xi32> + %arg0: tensor<2x?xi32>, + // CHECK: %arg1: tensor + %arg1: tensor, + // CHECK: %arg2: tensor<2x?xi32> + %arg2: tensor<2x?xi32>, + // CHECK: %arg3: tensor<64x9xf32> + %arg3: tensor) -> (tensor<2x?xi32>, tensor) { + + // CHECK: %arg0, %arg3 : tensor<2x16xi32>, tensor<64x9xf32> + return %arg0, %arg3 : tensor<2x?xi32>, tensor +}