Cleanup compiler plugin directory and include paths. (#16691)
Progress on https://github.com/openxla/iree/issues/15468 (I saw these
issues while migrating each plugin and finally got around to forming a
clear opinion and fixing them).
* Qualify include paths relative to the repository root, not the plugin
root or some artificial prefix like `torch-iree/`
* Collapse directory structures, e.g. `StableHLO/stablehlo-iree/` ->
`StableHLO/`
* Generate more CMake files from Bazel (still needed a few TableGen
tweaks though)
* Fix up `#ifndef` include guards after code moves
* Fix up a few copyright headers that were using the LLVM format instead
of the IREE one
diff --git a/compiler/plugins/input/TOSA/InputConversion/BUILD.bazel b/compiler/plugins/input/TOSA/InputConversion/BUILD.bazel
new file mode 100644
index 0000000..d0f0117
--- /dev/null
+++ b/compiler/plugins/input/TOSA/InputConversion/BUILD.bazel
@@ -0,0 +1,83 @@
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+load("//build_tools/bazel:build_defs.oss.bzl", "iree_compiler_cc_library", "iree_gentbl_cc_library")
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],  # Headers must come from direct deps.
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_gentbl_cc_library(  # Generates pass declarations from Passes.td.
+ name = "PassesIncGen",
+ tbl_outs = [
+ (
+ ["--gen-pass-decls"],
+ "Passes.h.inc",  # Included by PassDetail.h and Passes.cpp.
+ ),
+ ],
+ tblgen = "@llvm-project//mlir:mlir-tblgen",
+ td_file = "Passes.td",
+ deps = [
+ "@llvm-project//mlir:PassBaseTdFiles",  # Provides mlir/Pass/PassBase.td.
+ ],
+)
+
+iree_compiler_cc_library(  # Headers shared by the pass implementation files.
+ name = "PassHeaders",
+ hdrs = [
+ "PassDetail.h",
+ "Passes.h",
+ "Passes.h.inc",
+ ],
+ deps = [
+ ":PassesIncGen",  # Produces the generated Passes.h.inc.
+ "//compiler/src/iree/compiler/Dialect/LinalgExt/IR",  # LinalgExtDialect.h
+ "@llvm-project//mlir:ArithDialect",
+ "@llvm-project//mlir:FuncDialect",
+ "@llvm-project//mlir:FunctionInterfaces",
+ "@llvm-project//mlir:LinalgDialect",
+ "@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:TensorDialect",
+ "@llvm-project//mlir:Transforms",
+ ],
+)
+
+iree_compiler_cc_library(  # TOSA input-conversion passes and pipeline.
+ name = "InputConversion",
+ srcs = [
+ "Converti48Toi64.cpp",
+ "Passes.cpp",
+ "StripSignedness.cpp",
+ "TosaToLinalgExt.cpp",
+ "VerifyCompilerTOSAInputLegality.cpp",
+ ],
+ hdrs = [
+ "Passes.h",
+ ],
+ deps = [
+ ":PassHeaders",
+ ":PassesIncGen",
+ "//compiler/src/iree/compiler/Dialect/LinalgExt/IR",
+ "//compiler/src/iree/compiler/InputConversion/Common",  # Quantized matmul/conv passes.
+ "@llvm-project//mlir:ArithDialect",
+ "@llvm-project//mlir:FuncDialect",
+ "@llvm-project//mlir:FunctionInterfaces",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:LinalgDialect",
+ "@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:SCFToControlFlow",  # NOTE(review): no src here includes SCFToControlFlow headers; confirm still needed.
+ "@llvm-project//mlir:TensorDialect",
+ "@llvm-project//mlir:TosaDialect",
+ "@llvm-project//mlir:TosaToArith",
+ "@llvm-project//mlir:TosaToLinalg",
+ "@llvm-project//mlir:TosaToMLProgram",
+ "@llvm-project//mlir:TosaToSCF",
+ "@llvm-project//mlir:TosaToTensor",
+ "@llvm-project//mlir:Transforms",
+ ],
+)
diff --git a/compiler/plugins/input/TOSA/InputConversion/CMakeLists.txt b/compiler/plugins/input/TOSA/InputConversion/CMakeLists.txt
new file mode 100644
index 0000000..a1dc1ea
--- /dev/null
+++ b/compiler/plugins/input/TOSA/InputConversion/CMakeLists.txt
@@ -0,0 +1,76 @@
+################################################################################
+# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
+# compiler/plugins/input/TOSA/InputConversion/BUILD.bazel #
+# #
+# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
+# CMake-only content. #
+# #
+# To disable autogeneration for this file entirely, delete this header. #
+################################################################################
+
+iree_add_all_subdirs()  # Content up to the preserve marker is regenerated from BUILD.bazel.
+
+iree_tablegen_library(
+ NAME
+ PassesIncGen
+ TD_FILE
+ "Passes.td"
+ OUTS
+ --gen-pass-decls Passes.h.inc
+)
+
+iree_cc_library(
+ NAME
+ PassHeaders
+ HDRS
+ "PassDetail.h"
+ "Passes.h"
+ "Passes.h.inc"
+ DEPS
+ ::PassesIncGen
+ MLIRArithDialect
+ MLIRFuncDialect
+ MLIRFunctionInterfaces
+ MLIRLinalgDialect
+ MLIRPass
+ MLIRTensorDialect
+ MLIRTransforms
+ iree::compiler::Dialect::LinalgExt::IR
+ PUBLIC
+)
+
+iree_cc_library(
+ NAME
+ InputConversion
+ HDRS
+ "Passes.h"
+ SRCS
+ "Converti48Toi64.cpp"
+ "Passes.cpp"
+ "StripSignedness.cpp"
+ "TosaToLinalgExt.cpp"
+ "VerifyCompilerTOSAInputLegality.cpp"
+ DEPS
+ ::PassHeaders
+ ::PassesIncGen
+ MLIRArithDialect
+ MLIRFuncDialect
+ MLIRFunctionInterfaces
+ MLIRIR
+ MLIRLinalgDialect
+ MLIRPass
+ MLIRSCFToControlFlow  # NOTE(review): mirrors a possibly stale BUILD.bazel dep.
+ MLIRTensorDialect
+ MLIRTosaDialect
+ MLIRTosaToArith
+ MLIRTosaToLinalg
+ MLIRTosaToMLProgram
+ MLIRTosaToSCF
+ MLIRTosaToTensor
+ MLIRTransforms
+ iree::compiler::Dialect::LinalgExt::IR
+ iree::compiler::InputConversion::Common
+ PUBLIC
+)
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/compiler/plugins/input/TOSA/InputConversion/Converti48Toi64.cpp b/compiler/plugins/input/TOSA/InputConversion/Converti48Toi64.cpp
new file mode 100644
index 0000000..570982c
--- /dev/null
+++ b/compiler/plugins/input/TOSA/InputConversion/Converti48Toi64.cpp
@@ -0,0 +1,203 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "compiler/plugins/input/TOSA/InputConversion/PassDetail.h"
+#include "compiler/plugins/input/TOSA/InputConversion/Passes.h"
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+namespace mlir::iree_compiler {
+
+namespace {
+
+// Rewrites every i48 integer type in a function (signature, operand and
+// result types, region arguments and integer attributes) to i64.
+class Converti48Toi64Pass : public Converti48Toi64Base<Converti48Toi64Pass> {
+public:
+  explicit Converti48Toi64Pass() = default;
+  void runOnOperation() override;
+};
+
+// Maps 48-bit integer types to signless i64, and ranked tensors of them to
+// ranked tensors of i64. Every other type is returned unchanged.
+struct i48Toi64Converter : public TypeConverter {
+  static Type convertType(Type type) {
+    if (type.isInteger(48))
+      return IntegerType::get(type.getContext(), /*width=*/64);
+    return type;
+  }
+  static Type convertTensor(RankedTensorType type) {
+    // Keep the shape and encoding; only the element type changes.
+    return RankedTensorType::get(type.getShape(),
+                                 convertType(type.getElementType()),
+                                 type.getEncoding());
+  }
+  explicit i48Toi64Converter() {
+    addConversion([](Type type) { return convertType(type); });
+    addConversion(convertTensor);
+  }
+};
+
+// Resizes `value` to `width` bits. Unsigned source types are zero-extended;
+// signless and signed ones are sign-extended so that negative two's-
+// complement values (e.g. a negative i48 bound) keep their value.
+static APInt resizeInteger(const APInt &value, Type srcType, unsigned width) {
+  return srcType.isUnsignedInteger() ? value.zextOrTrunc(width)
+                                     : value.sextOrTrunc(width);
+}
+
+// Rebuilds an integer attribute (scalar or dense) with type `newType`,
+// resizing every stored value. Fails for other attribute kinds, or when
+// `newType` is null or not an integer (element) type.
+static FailureOr<Attribute> convertTypedAttr(TypedAttr attr, Type newType) {
+  if (!newType)
+    return failure();
+  if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
+    auto intType = dyn_cast<IntegerType>(newType);
+    if (!intType)
+      return failure();
+    APInt value = resizeInteger(intAttr.getValue(), intAttr.getType(),
+                                intType.getWidth());
+    return Attribute(IntegerAttr::get(intType, value));
+  }
+  if (auto denseAttr = dyn_cast<DenseIntElementsAttr>(attr)) {
+    auto shapedType = dyn_cast<ShapedType>(newType);
+    IntegerType elementType =
+        shapedType ? dyn_cast<IntegerType>(shapedType.getElementType())
+                   : IntegerType();
+    if (!elementType)
+      return failure();
+    Type srcElementType = denseAttr.getElementType();
+    return Attribute(
+        denseAttr.mapValues(elementType, [&](const APInt &value) {
+          return resizeInteger(value, srcElementType,
+                               elementType.getWidth());
+        }));
+  }
+  return failure();
+}
+
+// Recreates a non-function op with its result types, region argument types
+// and affected typed attributes converted to the i64 variants. Function
+// signatures are handled by the FunctionOpInterface conversion pattern.
+class GenericTypeConvert : public ConversionPattern {
+public:
+  GenericTypeConvert(MLIRContext *context, TypeConverter &converter)
+      : ConversionPattern(converter, MatchAnyOpTypeTag(), 0, context) {}
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (isa<mlir::FunctionOpInterface>(op))
+      return rewriter.notifyMatchFailure(op, "is a func op");
+
+    // Rebuild only the typed attributes whose type the converter changes;
+    // all others (e.g. float or i32 attributes) are copied verbatim.
+    llvm::SmallVector<NamedAttribute, 4> newAttrs;
+    for (NamedAttribute namedAttr : op->getAttrs()) {
+      auto typedAttr = dyn_cast<TypedAttr>(namedAttr.getValue());
+      Type oldType = typedAttr ? typedAttr.getType() : Type();
+      Type newType =
+          oldType ? getTypeConverter()->convertType(oldType) : Type();
+      if (!oldType || newType == oldType) {
+        newAttrs.push_back(namedAttr);
+        continue;
+      }
+      FailureOr<Attribute> newValue = convertTypedAttr(typedAttr, newType);
+      if (failed(newValue))
+        return rewriter.notifyMatchFailure(op, "Unsupported input type");
+      newAttrs.emplace_back(namedAttr.getName(), *newValue);
+    }
+
+    llvm::SmallVector<Type, 4> newResults;
+    if (failed(getTypeConverter()->convertTypes(op->getResultTypes(),
+                                                newResults)))
+      return rewriter.notifyMatchFailure(op, "failed to convert results");
+
+    OperationState state(op->getLoc(), op->getName().getStringRef(), operands,
+                         newResults, newAttrs, op->getSuccessors());
+    for (Region &r : op->getRegions()) {
+      Region *newRegion = state.addRegion();
+      rewriter.inlineRegionBefore(r, *newRegion, newRegion->begin());
+      TypeConverter::SignatureConversion result(newRegion->getNumArguments());
+      (void)getTypeConverter()->convertSignatureArgs(
+          newRegion->getArgumentTypes(), result);
+      rewriter.applySignatureConversion(newRegion, result);
+    }
+    Operation *newOp = rewriter.create(state);
+    rewriter.replaceOp(op, newOp->getResults());
+    return success();
+  }
+};
+
+// Returns true if `type` is i48, or a shaped type whose element type is
+// (recursively) i48.
+static bool isIllegalType(Type type) {
+  if (auto shapedType = dyn_cast<ShapedType>(type))
+    return isIllegalType(shapedType.getElementType());
+  return type.isInteger(48);
+}
+
+void Converti48Toi64Pass::runOnOperation() {
+  MLIRContext *ctx = &getContext();
+  i48Toi64Converter converter;
+  ConversionTarget target(*ctx);
+
+  // Operations are legal if they don't contain any illegal type in their
+  // signature (for functions), operands, results or typed attributes.
+  target.markUnknownOpDynamicallyLegal([](Operation *op) {
+    if (auto funcOp = dyn_cast<mlir::FunctionOpInterface>(op)) {
+      for (Type type : funcOp.getArgumentTypes()) {
+        if (isIllegalType(type))
+          return false;
+      }
+      for (Type type : funcOp.getResultTypes()) {
+        if (isIllegalType(type))
+          return false;
+      }
+    }
+    for (Type type : op->getResultTypes()) {
+      if (type && isIllegalType(type))
+        return false;
+    }
+    for (Type type : op->getOperandTypes()) {
+      if (type && isIllegalType(type))
+        return false;
+    }
+    for (NamedAttribute attr : op->getAttrs()) {
+      if (auto typedAttr = dyn_cast<TypedAttr>(attr.getValue())) {
+        if (isIllegalType(typedAttr.getType()))
+          return false;
+      }
+    }
+    return true;
+  });
+
+  FunctionOpInterface func = getOperation();
+  RewritePatternSet patterns(ctx);
+  patterns.add<GenericTypeConvert>(ctx, converter);
+  // Convert the signature of whichever function-like op this pass is
+  // anchored on, not only func.func.
+  populateFunctionOpInterfaceTypeConversionPattern(
+      func->getName().getStringRef(), patterns, converter);
+
+  if (failed(applyFullConversion(func, target, std::move(patterns))))
+    signalPassFailure();
+}
+
+} // namespace
+
+std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
+createConverti48Toi64() {
+  return std::make_unique<Converti48Toi64Pass>();
+}
+
+} // namespace mlir::iree_compiler
diff --git a/compiler/plugins/input/TOSA/InputConversion/PassDetail.h b/compiler/plugins/input/TOSA/InputConversion/PassDetail.h
new file mode 100644
index 0000000..6444c75
--- /dev/null
+++ b/compiler/plugins/input/TOSA/InputConversion/PassDetail.h
@@ -0,0 +1,25 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_PLUGINS_INPUT_TOSA_INPUTCONVERSION_PASSDETAIL_H_
+#define IREE_COMPILER_PLUGINS_INPUT_TOSA_INPUTCONVERSION_PASSDETAIL_H_
+
+#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir::iree_compiler {
+
+#define GEN_PASS_CLASSES // Emit the generated <PassName>Base classes.
+#include "compiler/plugins/input/TOSA/InputConversion/Passes.h.inc"
+
+} // namespace mlir::iree_compiler
+
+#endif // IREE_COMPILER_PLUGINS_INPUT_TOSA_INPUTCONVERSION_PASSDETAIL_H_
diff --git a/compiler/plugins/input/TOSA/InputConversion/Passes.cpp b/compiler/plugins/input/TOSA/InputConversion/Passes.cpp
new file mode 100644
index 0000000..85a8cac
--- /dev/null
+++ b/compiler/plugins/input/TOSA/InputConversion/Passes.cpp
@@ -0,0 +1,93 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "compiler/plugins/input/TOSA/InputConversion/Passes.h"
+
+#include "iree/compiler/InputConversion/Common/Passes.h"
+#include "mlir/Conversion/TosaToArith/TosaToArith.h"
+#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
+#include "mlir/Conversion/TosaToMLProgram/TosaToMLProgram.h"
+#include "mlir/Conversion/TosaToSCF/TosaToSCF.h"
+#include "mlir/Conversion/TosaToTensor/TosaToTensor.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Pass/PassOptions.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "mlir/Transforms/Passes.h"
+
+namespace mlir::iree_compiler {
+
+void registerTOSAConversionPassPipeline() {
+  // Registers the pipeline under its textual name; the registrar object is
+  // not used after construction.
+  PassPipelineRegistration<> tosaPipeline(
+      "iree-tosa-input-transformation-pipeline",
+      "Runs the TOSA IREE flow dialect transformation pipeline",
+      &buildTOSAInputConversionPassPipeline);
+}
+
+// Prepares TOSA input for the Flow dialect, ending with a legality check.
+void buildTOSAInputConversionPassPipeline(OpPassManager &passManager) {
+ passManager.addPass(mlir::createTosaToMLProgram());
+ // Lower TOSA control-flow ops to SCF. NOTE(review): an older comment here
+ // said SCF is then converted to CFG, but no SCF->CF pass is added directly
+ // in this function; confirm whether that still applies.
+ passManager.addNestedPass<func::FuncOp>(tosa::createTosaToSCF());
+
+ // We also don't handle calls well on the old codepath; until we remove the
+ // use of the CFG we can continue inlining.
+ passManager.addPass(mlir::createInlinerPass());
+
+ passManager.addNestedPass<func::FuncOp>(
+ tosa::createTosaMakeBroadcastablePass());
+ passManager.addNestedPass<func::FuncOp>(tosa::createTosaToArith());
+ passManager.addNestedPass<func::FuncOp>(tosa::createTosaToTensor());
+ passManager.addNestedPass<func::FuncOp>(
+ iree_compiler::createTosaToLinalgExt());
+ passManager.addNestedPass<func::FuncOp>(mlir::createCanonicalizerPass());
+
+ TosaToLinalgNamedOptions tosaToLinalgNamedOptions;
+ tosaToLinalgNamedOptions.preferConv2DKernelLayoutHWCF = true;
+ tosa::addTosaToLinalgPasses(passManager, TosaToLinalgOptions(),
+ tosaToLinalgNamedOptions);
+ passManager.addNestedPass<func::FuncOp>(
+ iree_compiler::createConverti48Toi64()); // Runs after TOSA->linalg passes.
+
+ // Sometimes we generate more TOSA operations during the lowering to linalg.
+ passManager.addNestedPass<func::FuncOp>(tosa::createTosaToArith());
+ passManager.addNestedPass<func::FuncOp>(tosa::createTosaToTensor());
+
+ passManager.addNestedPass<func::FuncOp>(
+ iree_compiler::createStripSignednessPass()); // ui/si ints -> signless.
+ passManager.addNestedPass<func::FuncOp>(mlir::createCanonicalizerPass());
+
+ passManager.addNestedPass<func::FuncOp>(
+ createLinalgQuantizedMatmulToMatmulPass());
+ passManager.addNestedPass<func::FuncOp>(
+ createLinalgQuantizedConvToConvPass());
+ passManager.addNestedPass<func::FuncOp>(mlir::createCanonicalizerPass());
+
+ //----------------------------------------------------------------------------
+ // Entry dialect cleanup
+ //----------------------------------------------------------------------------
+ passManager.addPass(createVerifyCompilerTOSAInputLegality()); // Leftover TOSA fails.
+}
+
+namespace {
+#define GEN_PASS_REGISTRATION // Emits the registerPasses() called below.
+#include "compiler/plugins/input/TOSA/InputConversion/Passes.h.inc" // IWYU pragma: export
+} // namespace
+
+void registerTOSAConversionPasses() {
+  // The end-to-end TOSA input pipeline.
+  registerTOSAConversionPassPipeline();
+
+  // Each individual pass declared in Passes.td (TableGen-generated).
+  registerPasses();
+}
+
+} // namespace mlir::iree_compiler
diff --git a/compiler/plugins/input/TOSA/InputConversion/Passes.h b/compiler/plugins/input/TOSA/InputConversion/Passes.h
new file mode 100644
index 0000000..f47e4b1
--- /dev/null
+++ b/compiler/plugins/input/TOSA/InputConversion/Passes.h
@@ -0,0 +1,57 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_PLUGINS_INPUT_TOSA_INPUTCONVERSION_PASSES_H_
+#define IREE_COMPILER_PLUGINS_INPUT_TOSA_INPUTCONVERSION_PASSES_H_
+
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir::iree_compiler {
+
+//===----------------------------------------------------------------------===//
+// Pipelines
+//===----------------------------------------------------------------------===//
+
+// Appends the passes that prepare TOSA input for the core IREE compiler.
+void buildTOSAInputConversionPassPipeline(OpPassManager &passManager);
+
+void registerTOSAConversionPassPipeline(); // Registers the TOSA pipeline.
+
+//------------------------------------------------------------------------------
+// Conversions from TOSA into Linalg and other core IREE dialects
+//------------------------------------------------------------------------------
+
+// Adds the patterns rewriting TOSA ops (currently tosa.scatter) to linalg_ext.
+void populateTosaToLinalgExtPatterns(RewritePatternSet *patterns);
+
+// Creates the pass that converts i48 types (incl. tensor elements) to i64.
+std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
+createConverti48Toi64();
+
+// Creates the pass that converts signed/unsigned integers to signless ones.
+std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
+createStripSignednessPass();
+
+// Creates the pass that converts tosa.scatter to iree_linalg_ext.scatter.
+std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
+createTosaToLinalgExt();
+
+// Verifies that a module only contains IR structures that are supported by the
+// core compiler.
+std::unique_ptr<OperationPass<ModuleOp>>
+createVerifyCompilerTOSAInputLegality();
+
+//===----------------------------------------------------------------------===//
+// Register all Passes
+//===----------------------------------------------------------------------===//
+
+void registerTOSAConversionPasses(); // All passes above plus the pipeline.
+
+} // namespace mlir::iree_compiler
+
+#endif // IREE_COMPILER_PLUGINS_INPUT_TOSA_INPUTCONVERSION_PASSES_H_
diff --git a/compiler/plugins/input/TOSA/InputConversion/Passes.td b/compiler/plugins/input/TOSA/InputConversion/Passes.td
new file mode 100644
index 0000000..7ea82d1
--- /dev/null
+++ b/compiler/plugins/input/TOSA/InputConversion/Passes.td
@@ -0,0 +1,42 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_PLUGINS_INPUT_TOSA_INPUTCONVERSION_PASSES
+#define IREE_COMPILER_PLUGINS_INPUT_TOSA_INPUTCONVERSION_PASSES
+
+include "mlir/Pass/PassBase.td"
+
+def Converti48Toi64 : // Implemented in Converti48Toi64.cpp.
+ InterfacePass<"iree-tosa-convert-i48-to-i64", "mlir::FunctionOpInterface"> {
+ let summary = "Converts all i48s to i64s";
+ let constructor = "mlir::iree_compiler::createConverti48Toi64()";
+}
+
+def StripSignedness : // Implemented in StripSignedness.cpp.
+    InterfacePass<"iree-tosa-strip-signedness", "mlir::FunctionOpInterface"> {
+  let summary = "Strips signedness from integer types (e.g. ui8 -> i8)";
+  let constructor = "mlir::iree_compiler::createStripSignednessPass()";
+}
+
+def TosaToLinalgExt :
+ InterfacePass<"iree-tosa-to-linalg-ext", "mlir::FunctionOpInterface"> {
+ let summary = "Convert TOSA operations to their equivalent linalg-ext operations.";
+ let constructor = "mlir::iree_compiler::createTosaToLinalgExt()";
+ let dependentDialects = [ // NOTE(review): tosa.concat is also created; TOSA is presumably already loaded.
+ "arith::ArithDialect", // arith.constant, arith.index_cast
+ "linalg::LinalgDialect", // linalg.fill, linalg.generic
+ "tensor::TensorDialect", // tensor.expand_shape/collapse_shape/empty/dim
+ "IREE::LinalgExt::IREELinalgExtDialect", // iree_linalg_ext.scatter
+ ];
+}
+
+def VerifyCompilerTOSAInputLegality : // TOSA ops other than tosa.apply_scale are illegal.
+ Pass<"iree-tosa-verify-compiler-input-legality", "ModuleOp"> {
+ let summary = "Verifies that only supported IR constructs are passed to the compiler.";
+ let constructor = "mlir::iree_compiler::createVerifyCompilerTOSAInputLegality()";
+}
+
+#endif // IREE_COMPILER_PLUGINS_INPUT_TOSA_INPUTCONVERSION_PASSES
diff --git a/compiler/plugins/input/TOSA/InputConversion/StripSignedness.cpp b/compiler/plugins/input/TOSA/InputConversion/StripSignedness.cpp
new file mode 100644
index 0000000..dc3ba5b
--- /dev/null
+++ b/compiler/plugins/input/TOSA/InputConversion/StripSignedness.cpp
@@ -0,0 +1,133 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "compiler/plugins/input/TOSA/InputConversion/PassDetail.h"
+#include "compiler/plugins/input/TOSA/InputConversion/Passes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir::iree_compiler {
+
+namespace {
+
+// Pass converting signed/unsigned integer types in a function to signless.
+struct StripSignednessPass : public StripSignednessBase<StripSignednessPass> {
+  explicit StripSignednessPass() = default;
+  void runOnOperation() override;
+};
+
+// Maps signed/unsigned integers (and ranked tensors of them) to signless
+// integers of the same width; every other type is returned unchanged.
+class IntegerTypeConverter : public TypeConverter {
+public:
+  static Type convertType(Type type) {
+    auto iType = llvm::dyn_cast<IntegerType>(type);
+    if (!iType || iType.isSignless())
+      return type;
+    return IntegerType::get(type.getContext(), iType.getIntOrFloatBitWidth());
+  }
+  static Type convertTensor(RankedTensorType type) {
+    // Keep the shape and encoding; only the element type changes.
+    return RankedTensorType::get(type.getShape(),
+                                 convertType(type.getElementType()),
+                                 type.getEncoding());
+  }
+  explicit IntegerTypeConverter() {
+    addConversion([](Type type) { return convertType(type); });
+    addConversion(convertTensor);
+  }
+};
+
+// Recreates a non-function op with its result types and region argument
+// types converted by the type converter (signed/unsigned integers become
+// signless). Attributes, including typed ones, are copied unchanged.
+class GenericTypeConvert : public ConversionPattern {
+public:
+  GenericTypeConvert(MLIRContext *context, TypeConverter &converter)
+      : ConversionPattern(converter, MatchAnyOpTypeTag(), 0, context) {}
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (isa<FunctionOpInterface>(op))
+      return rewriter.notifyMatchFailure(op, "is a func op");
+    llvm::SmallVector<Type> newResults;
+    if (failed(getTypeConverter()->convertTypes(op->getResultTypes(),
+                                                newResults)))
+      return rewriter.notifyMatchFailure(op, "failed to convert results");
+    OperationState state(op->getLoc(), op->getName().getStringRef(), operands,
+                         newResults, op->getAttrs(), op->getSuccessors());
+    for (Region &r : op->getRegions()) {
+      Region *newRegion = state.addRegion();
+      rewriter.inlineRegionBefore(r, *newRegion, newRegion->begin());
+      TypeConverter::SignatureConversion result(newRegion->getNumArguments());
+      (void)getTypeConverter()->convertSignatureArgs(
+          newRegion->getArgumentTypes(), result);
+      rewriter.applySignatureConversion(newRegion, result);
+    }
+    Operation *newOp = rewriter.create(state);
+    rewriter.replaceOp(op, newOp->getResults());
+    return success();
+  }
+};
+
+// True if `type` is a signed/unsigned integer, or a shaped type whose element
+// type is (recursively) such an integer.
+static bool isIllegalType(Type type) {
+  if (auto shapedType = llvm::dyn_cast<ShapedType>(type))
+    return isIllegalType(shapedType.getElementType());
+  auto intType = llvm::dyn_cast<IntegerType>(type);
+  return intType && !intType.isSignless();
+}
+
+void StripSignednessPass::runOnOperation() {
+ IntegerTypeConverter converter;
+ ConversionTarget target(getContext());
+
+ // Legal iff no signature/operand/result type is a signed/unsigned integer.
+ target.markUnknownOpDynamicallyLegal([](Operation *op) {
+ if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
+ for (Type type : funcOp.getArgumentTypes()) {
+ if (isIllegalType(type))
+ return false;
+ }
+ for (Type type : funcOp.getResultTypes()) {
+ if (isIllegalType(type))
+ return false;
+ }
+ }
+ for (Type type : op->getResultTypes()) {
+ if (type && isIllegalType(type))
+ return false;
+ }
+ for (Type type : op->getOperandTypes()) {
+ if (type && isIllegalType(type))
+ return false;
+ }
+ return true; // NOTE(review): typed attributes are not inspected here.
+ });
+
+ auto *ctx = &getContext();
+
+ RewritePatternSet patterns(&getContext());
+ patterns.insert<GenericTypeConvert>(ctx, converter);
+ populateFunctionOpInterfaceTypeConversionPattern( // Function signatures.
+ getOperation()->getName().getStringRef(), patterns, converter);
+
+ if (failed(
+ applyFullConversion(getOperation(), target, std::move(patterns)))) {
+ signalPassFailure();
+ }
+}
+
+} // namespace
+
+std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
+createStripSignednessPass() { // Factory for "iree-tosa-strip-signedness".
+ return std::make_unique<StripSignednessPass>();
+}
+
+} // namespace mlir::iree_compiler
diff --git a/compiler/plugins/input/TOSA/InputConversion/TosaToLinalgExt.cpp b/compiler/plugins/input/TOSA/InputConversion/TosaToLinalgExt.cpp
new file mode 100644
index 0000000..799dd04
--- /dev/null
+++ b/compiler/plugins/input/TOSA/InputConversion/TosaToLinalgExt.cpp
@@ -0,0 +1,171 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "compiler/plugins/input/TOSA/InputConversion/PassDetail.h"
+#include "compiler/plugins/input/TOSA/InputConversion/Passes.h"
+#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
+#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+using namespace mlir::tosa;
+
+namespace mlir::iree_compiler {
+
+namespace {
+// Converts tosa.scatter to the iree_linalg_ext.scatter operation. Because the
+// LinalgExt version is not batched, the batch index is materialized for each
+// update.
+class ScatterConversion : public OpRewritePattern<tosa::ScatterOp> {
+public:
+  using OpRewritePattern<tosa::ScatterOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::ScatterOp op,
+                                PatternRewriter &rewriter) const final {
+    auto values = op.getValuesIn();
+    auto indices = llvm::cast<Value>(op.getIndices());
+    auto updates = op.getInput();
+    auto valuesTy = llvm::dyn_cast<RankedTensorType>(values.getType());
+    auto indicesTy = llvm::dyn_cast<RankedTensorType>(indices.getType());
+    auto updatesTy = llvm::dyn_cast<RankedTensorType>(updates.getType());
+    ImplicitLocOpBuilder builder(op.getLoc(), rewriter);
+
+    if (!valuesTy || !indicesTy || !updatesTy)
+      return rewriter.notifyMatchFailure(op,
+                                         "tosa.scatter has unknown input rank");
+
+    // TOSA's scatter does not include an index dimension; instead it
+    // implicitly supports an index depth of one. We materialize that implicit
+    // index of one as follows: [batch, updates] -> [batch, updates,
+    // index_depth=1] with an indexing map of [[0], [1, 2]].
+    llvm::SmallVector<int64_t> expandIndShape{indicesTy.getDimSize(0),
+                                              indicesTy.getDimSize(1), 1};
+    SmallVector<ReassociationExprs> expandIndMap;
+    expandIndMap.push_back({
+        builder.getAffineDimExpr(0),
+    });
+    expandIndMap.push_back({
+        builder.getAffineDimExpr(1),
+        builder.getAffineDimExpr(2),
+    });
+
+    indices = builder.create<tensor::ExpandShapeOp>(
+        indicesTy.clone(expandIndShape), indices, expandIndMap);
+    indicesTy = llvm::dyn_cast<RankedTensorType>(indices.getType());
+
+    // Materialize the batch index, as LinalgExt scatter is not batched.
+    {
+      llvm::SmallVector<Value> dynDims;
+      for (int i = 0, s = indicesTy.getRank(); i < s; ++i)
+        if (indicesTy.isDynamicDim(i))
+          dynDims.push_back(builder.create<tensor::DimOp>(indices, i));
+
+      Value empty = builder.create<tensor::EmptyOp>(
+          indicesTy.getShape(), indicesTy.getElementType(), dynDims);
+      Value batchIdx = nullptr;
+      if (indicesTy.getDimSize(0) == 1) {
+        Value zero = builder.create<arith::ConstantOp>(
+            rewriter.getZeroAttr(indicesTy.getElementType()));
+        batchIdx = builder.create<linalg::FillOp>(zero, empty).getResult(0);
+      } else {
+        SmallVector<utils::IteratorType> iterators(
+            indicesTy.getRank(), utils::IteratorType::parallel);
+        SmallVector<AffineMap, 3> indexingMaps(
+            2, builder.getMultiDimIdentityMap(indicesTy.getRank()));
+
+        auto blockBuilder = [&](OpBuilder &nestedBuilder, Location nestedLoc,
+                                ValueRange blockArgs) {
+          ImplicitLocOpBuilder b(op.getLoc(), nestedBuilder);
+          auto index = b.create<linalg::IndexOp>(0);
+          auto cast =
+              b.create<arith::IndexCastOp>(indicesTy.getElementType(), index);
+          b.create<linalg::YieldOp>(cast.getResult());
+        };
+        batchIdx = builder
+                       .create<linalg::GenericOp>(indicesTy, indices, empty,
+                                                  indexingMaps, iterators,
+                                                  blockBuilder)
+                       .getResult(0);
+      }
+
+      indicesTy = llvm::cast<RankedTensorType>(indicesTy.clone(
+          {indicesTy.getDimSize(0), indicesTy.getDimSize(1), 2}));
+      indices = builder.create<tosa::ConcatOp>(indicesTy,
+                                               ValueRange{batchIdx, indices},
+                                               rewriter.getI32IntegerAttr(2));
+    }
+
+    auto collapseBatch = [](Value value, ImplicitLocOpBuilder &b) -> Value {
+      auto valueTy = llvm::cast<ShapedType>(value.getType());
+      llvm::SmallVector<int64_t> collapseShape(valueTy.getShape().drop_front());
+      llvm::SmallVector<ReassociationExprs> collapseMap(valueTy.getRank() - 1);
+      collapseMap.front().push_back(b.getAffineDimExpr(0));
+      for (int i = 0, s = collapseMap.size(); i < s; ++i) {
+        collapseMap[i].push_back(b.getAffineDimExpr(i + 1));
+      }
+
+      int64_t batch = valueTy.getShape().front();
+      int64_t rows = collapseShape.front();
+      bool batchDyn = ShapedType::isDynamic(batch);
+      bool rowsDyn = ShapedType::isDynamic(rows);
+      collapseShape[0] =
+          (batchDyn || rowsDyn) ? ShapedType::kDynamic : batch * rows;
+
+      return b.create<tensor::CollapseShapeOp>(valueTy.clone(collapseShape),
+                                               value, collapseMap);
+    };
+
+    indices = collapseBatch(indices, builder);
+    updates = collapseBatch(updates, builder);
+
+    // Create the LinalgExt scatter operation.
+    auto scatter = builder.create<IREE::LinalgExt::ScatterOp>(
+        TypeRange{values.getType()}, ValueRange{updates, indices},
+        ValueRange{values}, builder.getDenseI64ArrayAttr({0, 1}),
+        builder.getBoolAttr(true));
+
+    llvm::SmallVector<Type> args(2, valuesTy.getElementType());
+    Block *scatterBody =
+        builder.createBlock(&scatter.getRegion(), {}, args,
+                            llvm::SmallVector<Location>(2, op.getLoc()));
+    builder.setInsertionPointToStart(scatterBody);
+    builder.create<IREE::LinalgExt::YieldOp>(scatterBody->getArgument(0));
+    rewriter.replaceOp(op, scatter.getResult(0));
+    return success();
+  }
+};
+} // namespace
+
+// Rewrites tosa.scatter into iree_linalg_ext.scatter; other ops are kept.
+struct TosaToLinalgExtPass : public TosaToLinalgExtBase<TosaToLinalgExtPass> {
+  void runOnOperation() override {
+    ConversionTarget target(getContext());
+    target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
+    target.addIllegalOp<tosa::ScatterOp>();
+    RewritePatternSet patterns(&getContext());
+    populateTosaToLinalgExtPatterns(&patterns);
+    if (failed(applyFullConversion(getOperation(), target,
+                                   std::move(patterns))))
+      signalPassFailure();
+  }
+};
+
+void populateTosaToLinalgExtPatterns(RewritePatternSet *patterns) {
+ patterns->add<ScatterConversion>(patterns->getContext()); // tosa.scatter
+}
+
+std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
+createTosaToLinalgExt() { // Factory for "iree-tosa-to-linalg-ext".
+ return std::make_unique<TosaToLinalgExtPass>();
+}
+
+} // namespace mlir::iree_compiler
diff --git a/compiler/plugins/input/TOSA/InputConversion/VerifyCompilerTOSAInputLegality.cpp b/compiler/plugins/input/TOSA/InputConversion/VerifyCompilerTOSAInputLegality.cpp
new file mode 100644
index 0000000..dd0f275
--- /dev/null
+++ b/compiler/plugins/input/TOSA/InputConversion/VerifyCompilerTOSAInputLegality.cpp
@@ -0,0 +1,71 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "compiler/plugins/input/TOSA/InputConversion/PassDetail.h"
+#include "compiler/plugins/input/TOSA/InputConversion/Passes.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir::iree_compiler {
+
+/// Module pass that fails if any op from the TOSA dialect (other than
+/// `tosa.apply_scale`) is still present.
+///
+/// Illegal ops are not reported individually. Their legalization diagnostics
+/// are collected and re-attached as notes on a single error emitted at the
+/// module's location.
+struct VerifyCompilerTOSAInputLegalityPass
+    : public VerifyCompilerTOSAInputLegalityBase<
+          VerifyCompilerTOSAInputLegalityPass> {
+  void runOnOperation() override {
+    auto *context = &getContext();
+    ConversionTarget conversionTarget(*context);
+    // Intentionally empty: this pass only checks legality and never rewrites.
+    RewritePatternSet conversionPatterns(context);
+
+    // Note that we would prefer allow-lists of what we positively support.
+    // However, it is so common to sneak input-level ops into the pipeline
+    // that we explicitly deny the dialects we know about.
+    conversionTarget.addIllegalDialect<tosa::TosaDialect>();
+
+    // Exception: ApplyScaleOp is actually a lowered op on par with standard
+    // dialect.
+    conversionTarget.addLegalOp<tosa::ApplyScaleOp>();
+
+    // NOTE: It is not fully illegal to tunnel input dialect ops through to
+    // backends that expect them. When such situations arise, the container
+    // op should be marked recursively legal here.
+    SmallVector<Diagnostic> failures;
+    {
+      // Capture (instead of printing) the per-op diagnostics produced while
+      // legalizing, so they can be folded into one error below. The handler
+      // is uninstalled when this scope ends.
+      ScopedDiagnosticHandler diag(context,
+                                   [&](Diagnostic &d) -> LogicalResult {
+                                     failures.push_back(std::move(d));
+                                     return success();
+                                   });
+      if (succeeded(applyPartialConversion(getOperation(), conversionTarget,
+                                           std::move(conversionPatterns)))) {
+        return;
+      }
+    }
+
+    // Error fall-through. Attach all reported issues as notes.
+    InFlightDiagnostic errorDiag =
+        emitError(getOperation().getLoc())
+        << "one or more illegal operations were found in the compiler input "
+           "(are you missing an --iree-input-type= flag, or did you mean to "
+           "pre-process through an IREE importer frontend?)";
+    for (auto &failureDiag : failures) {
+      // A reference to the note owned by `errorDiag`; the captured message
+      // arguments are appended to it in order.
+      Diagnostic &note = errorDiag.attachNote(failureDiag.getLocation());
+      for (auto &arg : failureDiag.getArguments()) {
+        note.append(arg);
+      }
+    }
+
+    signalPassFailure();
+  }
+};
+
+/// Returns a new instance of the module pass that rejects leftover TOSA ops
+/// (`VerifyCompilerTOSAInputLegalityPass`).
+std::unique_ptr<OperationPass<ModuleOp>>
+createVerifyCompilerTOSAInputLegality() {
+  std::unique_ptr<OperationPass<ModuleOp>> pass =
+      std::make_unique<VerifyCompilerTOSAInputLegalityPass>();
+  return pass;
+}
+
+} // namespace mlir::iree_compiler
diff --git a/compiler/plugins/input/TOSA/InputConversion/test/BUILD.bazel b/compiler/plugins/input/TOSA/InputConversion/test/BUILD.bazel
new file mode 100644
index 0000000..2e692a5
--- /dev/null
+++ b/compiler/plugins/input/TOSA/InputConversion/test/BUILD.bazel
@@ -0,0 +1,40 @@
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+load("//build_tools/bazel:iree_lit_test.bzl", "iree_lit_test_suite")
+load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob")
+
+package(
+    features = ["layering_check"],
+    licenses = ["notice"],  # Apache 2.0
+)
+
+# Lit tests for the TOSA input-conversion passes. Each listed .mlir file has
+# its own RUN line driving iree-opt or iree-compile into FileCheck.
+iree_lit_test_suite(
+    name = "lit",
+    # The explicit list is wrapped in enforce_glob, which is expected to flag
+    # drift from the "*.mlir" glob (minus `exclude`), so new tests go here too.
+    srcs = enforce_glob(
+        [
+            "apply_pdl_patterns_tosa.mlir",
+            "auto_input_conversion.mlir",
+            "convert_i48_to_i64.mlir",
+            "strip_signedness.mlir",
+            "tosa_to_linalg_ext.mlir",
+            "verify_compiler_tosa_input_legality.mlir",
+        ],
+        include = ["*.mlir"],
+        exclude = [
+            # Not a test: PDL patterns loaded by apply_pdl_patterns_tosa.mlir.
+            "tosa.pdl.mlir",
+        ],
+    ),
+    cfg = "//compiler:lit.cfg.py",
+    # Needed at test time via the `patterns-file=%p/tosa.pdl.mlir` option.
+    data = [
+        "tosa.pdl.mlir",
+    ],
+    tools = [
+        "//tools:iree-compile",
+        "//tools:iree-opt",
+        "@llvm-project//llvm:FileCheck",
+    ],
+)
diff --git a/compiler/plugins/input/TOSA/InputConversion/test/CMakeLists.txt b/compiler/plugins/input/TOSA/InputConversion/test/CMakeLists.txt
new file mode 100644
index 0000000..6c58e41
--- /dev/null
+++ b/compiler/plugins/input/TOSA/InputConversion/test/CMakeLists.txt
@@ -0,0 +1,31 @@
+################################################################################
+# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
+# compiler/plugins/input/TOSA/InputConversion/test/BUILD.bazel #
+# #
+# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
+# CMake-only content. #
+# #
+# To disable autogeneration for this file entirely, delete this header. #
+################################################################################
+
+iree_add_all_subdirs()
+
+iree_lit_test_suite(
+  NAME
+    lit
+  SRCS
+    "apply_pdl_patterns_tosa.mlir"
+    "auto_input_conversion.mlir"
+    "convert_i48_to_i64.mlir"
+    "strip_signedness.mlir"
+    "tosa_to_linalg_ext.mlir"
+    "verify_compiler_tosa_input_legality.mlir"
+  TOOLS
+    FileCheck
+    iree-compile
+    iree-opt
+  DATA
+    tosa.pdl.mlir
+)
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
+
+# The `lit` suite above is generated from this directory's BUILD.bazel; change
+# the test list there and regenerate instead of editing it here.
diff --git a/compiler/plugins/input/TOSA/InputConversion/test/apply_pdl_patterns_tosa.mlir b/compiler/plugins/input/TOSA/InputConversion/test/apply_pdl_patterns_tosa.mlir
new file mode 100644
index 0000000..c889013
--- /dev/null
+++ b/compiler/plugins/input/TOSA/InputConversion/test/apply_pdl_patterns_tosa.mlir
@@ -0,0 +1,52 @@
+// RUN: iree-opt --pass-pipeline="builtin.module(iree-preprocessing-apply-pdl-patterns{patterns-file=%p/tosa.pdl.mlir})" %s | FileCheck %s
+
+// CHECK-LABEL: stream.executable private @mlp_external_executable
+// CHECK: stream.executable.export public @mlp_external_entry_point
+// CHECK: builtin.module
+// CHECK: func.func private @mlp_external
+// CHECK-SAME: (memref<f32>, index, memref<f32>, index, memref<f32>, index, i32, i32, i32)
+// CHECK-SAME: attributes {llvm.bareptr = [true]}
+// CHECK: func.func @mlp_external_entry_point
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: !stream.binding
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: !stream.binding
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: !stream.binding
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: i32
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: i32
+// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: i32
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[STREAM0:.+]] = stream.binding.subspan %[[ARG0]][%[[C0]]] : !stream.binding -> memref<1x2x4xf32>
+// CHECK-NEXT: %[[STREAM0_BASE:[a-zA-Z0-9_]+]],
+// CHECK-SAME: = memref.extract_strided_metadata %[[STREAM0]]
+// CHECK: %[[STREAM1:.+]] = stream.binding.subspan %[[ARG1]][%[[C0]]] : !stream.binding -> memref<1x4x8xf32>
+// CHECK-NEXT: %[[STREAM1_BASE:[a-zA-Z0-9_]+]],
+// CHECK-SAME: = memref.extract_strided_metadata %[[STREAM1]]
+// CHECK: %[[STREAM2:.+]] = stream.binding.subspan %[[ARG2]][%[[C0]]] : !stream.binding -> memref<1x2x8xf32>
+// CHECK-NEXT: %[[STREAM2_BASE:[a-zA-Z0-9_]+]],
+// CHECK-SAME: = memref.extract_strided_metadata %[[STREAM2]]
+// CHECK: call @mlp_external
+// CHECK-SAME: %[[STREAM0_BASE]], %[[C0]], %[[STREAM1_BASE]], %[[C0]], %[[STREAM2_BASE]], %[[C0]], %[[ARG3]], %[[ARG4]], %[[ARG5]]
+
+// CHECK: func.func @mlp_invocation
+// CHECK-SAME: (%[[ARG0:.+]]: tensor<2x4xf32>, %[[ARG1:.+]]: tensor<4x8xf32>)
+// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : i32
+// CHECK-DAG: %[[C8:.+]] = arith.constant 8 : i32
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : i32
+// CHECK-DAG: %[[LHS:.+]] = tosa.reshape %[[ARG0]]
+// CHECK-DAG: %[[RHS:.+]] = tosa.reshape %[[ARG1]]
+// CHECK: %[[RESULT:.+]] = flow.dispatch
+// CHECK-SAME: @mlp_external_executable::@mlp_external_entry_point
+// CHECK-SAME: (%[[LHS]], %[[RHS]], %[[C2]], %[[C8]], %[[C4]])
+// CHECK: tosa.negate %[[RESULT]]
+
+// Input for the PDL offload: relu(matmul(lhs, rhs)) on 3-D reshapes of the 2-D
+// operands. The trailing negate consumes the dispatch result and must stay
+// outside of it.
+func.func @mlp_invocation(%lhs: tensor<2x4xf32>, %rhs : tensor<4x8xf32>) -> tensor<2x8xf32> {
+  // Each reshape's `new_shape` must agree with its declared result type.
+  %lhs_3D = tosa.reshape %lhs {new_shape = array<i64 : 1, 2, 4>} : (tensor<2x4xf32>) -> tensor<1x2x4xf32>
+  %rhs_3D = tosa.reshape %rhs {new_shape = array<i64 : 1, 4, 8>} : (tensor<4x8xf32>) -> tensor<1x4x8xf32>
+  %0 = tosa.matmul %lhs_3D, %rhs_3D : (tensor<1x2x4xf32>, tensor<1x4x8xf32>) -> tensor<1x2x8xf32>
+  %1 = tosa.clamp %0 {
+      min_int = 0 : i64, max_int = 9223372036854775807 : i64,
+      min_fp = 0.0 : f32, max_fp = 3.4028235e+38 : f32}
+      : (tensor<1x2x8xf32>) -> tensor<1x2x8xf32>
+  %2 = tosa.negate %1 : (tensor<1x2x8xf32>) -> tensor<1x2x8xf32>
+  %3 = tosa.reshape %2 {new_shape = array<i64 : 2, 8>} : (tensor<1x2x8xf32>) -> tensor<2x8xf32>
+  return %3 : tensor<2x8xf32>
+}
diff --git a/compiler/plugins/input/TOSA/InputConversion/test/auto_input_conversion.mlir b/compiler/plugins/input/TOSA/InputConversion/test/auto_input_conversion.mlir
new file mode 100644
index 0000000..145f2d9
--- /dev/null
+++ b/compiler/plugins/input/TOSA/InputConversion/test/auto_input_conversion.mlir
@@ -0,0 +1,10 @@
+// RUN: iree-compile --compile-to=input --split-input-file %s | FileCheck %s
+
+// Check that the auto input conversion pipeline uses this plugin.
+
+// CHECK-LABEL: util.func public @simple_add_tosa
+// CHECK: arith.addi
+// A plain TOSA elementwise add; without any --iree-input-type flag the auto
+// input conversion should still lower it to arith.
+func.func @simple_add_tosa(%lhs: tensor<2x2xi32>, %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
+  %sum = tosa.add %lhs, %rhs : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
+  return %sum : tensor<2x2xi32>
+}
diff --git a/compiler/plugins/input/TOSA/InputConversion/test/convert_i48_to_i64.mlir b/compiler/plugins/input/TOSA/InputConversion/test/convert_i48_to_i64.mlir
new file mode 100644
index 0000000..5e0a558
--- /dev/null
+++ b/compiler/plugins/input/TOSA/InputConversion/test/convert_i48_to_i64.mlir
@@ -0,0 +1,31 @@
+// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(iree-tosa-convert-i48-to-i64))" --verify-diagnostics %s | FileCheck %s
+
+// CHECK-LABEL: @test_all_i48_converted
+// Every i48 tensor (argument, op operands/results and the returned value)
+// is rewritten to i64.
+func.func @test_all_i48_converted(%arg0: tensor<2x2xi48>) -> tensor<2x2xi48> {
+  // CHECK: %[[ADD:.+]] = tosa.add %arg0, %arg0 : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi64>
+  // CHECK: %[[SUB:.+]] = tosa.sub %[[ADD]], %arg0 : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi64>
+  // CHECK: return %[[SUB]] : tensor<2x2xi64>
+  %0 = tosa.add %arg0, %arg0 : (tensor<2x2xi48>, tensor<2x2xi48>) -> tensor<2x2xi48>
+  %1 = tosa.sub %0, %arg0 : (tensor<2x2xi48>, tensor<2x2xi48>) -> tensor<2x2xi48>
+  return %1 : tensor<2x2xi48>
+}
+
+// CHECK-LABEL: @test_other_types_not_converted
+// Integer widths other than 48 (here i32) pass through unchanged.
+func.func @test_other_types_not_converted(%arg0: tensor<2x2xi32>) -> tensor<2x2xi32> {
+  // CHECK: %[[ADD:.+]] = tosa.add %arg0, %arg0 : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
+  // CHECK: %[[SUB:.+]] = tosa.sub %[[ADD]], %arg0 : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
+  // CHECK: return %[[SUB]] : tensor<2x2xi32>
+  %0 = tosa.add %arg0, %arg0 : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
+  %1 = tosa.sub %0, %arg0 : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
+  return %1 : tensor<2x2xi32>
+}
+
+// CHECK-LABEL: @test_attrs_converted
+// i48-typed constant attributes are widened too: a scalar arith.constant and
+// a dense tosa.const both come out as i64.
+func.func @test_attrs_converted() -> (i48, tensor<2xi48>) {
+  // CHECK: %[[ARITH_C:.+]] = arith.constant 0 : i64
+  // CHECK: %[[TOSA_C:.+]] = "tosa.const"() <{value = dense<0> : tensor<2xi64>}> : () -> tensor<2xi64>
+  // CHECK: return %[[ARITH_C]], %[[TOSA_C]] : i64, tensor<2xi64>
+  %0 = "arith.constant"() {value = 0 : i48} : () -> i48
+  %1 = "tosa.const"() <{value = dense<0> : tensor<2xi48>}> : () -> tensor<2xi48>
+  return %0, %1 : i48, tensor<2xi48>
+}
diff --git a/compiler/plugins/input/TOSA/InputConversion/test/strip_signedness.mlir b/compiler/plugins/input/TOSA/InputConversion/test/strip_signedness.mlir
new file mode 100644
index 0000000..cf30e57
--- /dev/null
+++ b/compiler/plugins/input/TOSA/InputConversion/test/strip_signedness.mlir
@@ -0,0 +1,22 @@
+
+// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(func.func(iree-tosa-strip-signedness))' %s | FileCheck %s
+
+// CHECK-LABEL: @strip_signedness_arg
+// CHECK-SAME: tensor<4xi8>
+// An unsigned (ui8) argument and result become signless i8, both in the
+// function signature and at the return.
+func.func @strip_signedness_arg(%arg0 : tensor<4xui8>) -> (tensor<4xui8>) {
+  // CHECK: return
+  // CHECK-SAME: tensor<4xi8>
+  return %arg0 : tensor<4xui8>
+}
+
+// -----
+
+// CHECK-LABEL: @strip_signedness_const
+// NOTE(review): this constant is already signless (i8), so the case only shows
+// that signless constants are left alone. A ui8 constant would actually
+// exercise the stripping; confirm which was intended.
+func.func @strip_signedness_const() -> (tensor<4xi8>) {
+  // CHECK: constant
+  // CHECK-SAME: tensor<4xi8>
+  %0 = arith.constant dense<[0, 2, 3, 7]> : tensor<4xi8>
+  // CHECK: return
+  // CHECK-SAME: tensor<4xi8>
+  return %0 : tensor<4xi8>
+}
diff --git a/compiler/plugins/input/TOSA/InputConversion/test/tosa.pdl.mlir b/compiler/plugins/input/TOSA/InputConversion/test/tosa.pdl.mlir
new file mode 100644
index 0000000..e4f02da
--- /dev/null
+++ b/compiler/plugins/input/TOSA/InputConversion/test/tosa.pdl.mlir
@@ -0,0 +1,125 @@
+// PDL pattern spec to match an MLP and offload to an external function
+//
+// ```
+// void mlp_external(void *params, void *context, void *reserved)
+// ```
+//
+// which is the expected signature of an external function implemented and
+// provided by a system plugin. See
+// samples/custom_dispatch/cpu/plugin/system_plugin.c for an example.
+//
+// The `params` is the following struct
+//
+// ```
+// struct mlp_params_t {
+//   const float *restrict lhs;
+//   size_t lhs_offset;
+//   const float *restrict rhs;
+//   size_t rhs_offset;
+//   float *restrict result;
+//   size_t result_offset;
+//   int32_t M;
+//   int32_t N;
+//   int32_t K;
+// };
+// ```
+//
+// The field order follows the generated call: the three buffers (each with
+// its offset) come first, then `M`, `N` and `K`. In MLIR this corresponds to
+// the function
+//
+// ```
+// func.func @mlp_external(%lhs : memref<..xf32>, %rhs : memref<..xf32>,
+//     %result : memref<..xf32>, %M: i32, %N : i32, %K : i32)
+// ```
+//
+// Note: In the above struct a `pointer, offset` pair represents a buffer
+// passed into the external function. So any access to `lhs`, `rhs` and
+// `result` is valid only if accessed as `lhs[lhs_offset + ...]`,
+// `rhs[rhs_offset + ...]` and `result[result_offset + ...]`.
+pdl.pattern @mlp : benefit(1) {
+
+  // PDL matcher for the MLP computation: a `tosa.matmul` of `%lhs` and `%rhs`
+  // whose result feeds a `tosa.clamp` with `min_int = 0` and `min_fp = 0.0`,
+  // i.e. a ReLU (`max_int` / `max_fp` are matched but left unconstrained).
+  // The rewrite below replaces that clamp with a dispatch equivalent to
+  //
+  // ```
+  // %result = func.call @mlp_external(%lhs : tensor<...xf32>,
+  //     %rhs : tensor<..xf32>, %M : i32, %N : i32, %K : i32) -> tensor<..xf32>
+  // ```
+  %lhs_type = pdl.type
+  %lhs = pdl.operand : %lhs_type
+  %rhs_type = pdl.type
+  %rhs = pdl.operand : %rhs_type
+  %matmul_type = pdl.type
+  // Clamp bounds: only the lower bounds are pinned (to zero).
+  %min_int = pdl.attribute = 0 : i64
+  %max_int = pdl.attribute
+  %min_fp = pdl.attribute = 0.0 : f32
+  %max_fp = pdl.attribute
+  %matmul = pdl.operation "tosa.matmul"(%lhs, %rhs : !pdl.value, !pdl.value)
+      -> (%matmul_type : !pdl.type)
+  // Native constraint (registered outside this file) that presumably requires
+  // f32 elements for both operands and the matmul result.
+  %element_type = pdl.type : f32
+  pdl.apply_native_constraint "checkTensorElementType"(%lhs_type, %element_type : !pdl.type, !pdl.type)
+  pdl.apply_native_constraint "checkTensorElementType"(%rhs_type, %element_type : !pdl.type, !pdl.type)
+  pdl.apply_native_constraint "checkTensorElementType"(%matmul_type, %element_type : !pdl.type, !pdl.type)
+
+  %matmul_result = pdl.result 0 of %matmul
+  %relu_type = pdl.type
+  %relu = pdl.operation "tosa.clamp"(%matmul_result : !pdl.value) {
+      "min_int" = %min_int, "max_int" = %max_int,
+      "min_fp" = %min_fp, "max_fp" = %max_fp}
+      -> (%relu_type : !pdl.type)
+
+  pdl.rewrite %matmul {
+    // The matcher above captured `%lhs`, `%rhs` and the clamp whose result is
+    // replaced. `%M`, `%N` and `%K` are not captured; they are computed here
+    // from the operand shapes and cast to i32:
+    //   M = dim 1 of %lhs, K = dim 2 of %lhs, N = dim 2 of %rhs.
+    // Dim 0 (the batch dimension of the 3-D operands) is not passed.
+    // NOTE(review): presumably the batch is always 1 here; confirm.
+    %one_val = pdl.attribute = 1 : index
+    %two_val = pdl.attribute = 2 : index
+    %index_type = pdl.type : index
+    %one_op = pdl.operation "arith.constant" {"value" = %one_val} -> (%index_type : !pdl.type)
+    %one = pdl.result 0 of %one_op
+    %two_op = pdl.operation "arith.constant" {"value" = %two_val} -> (%index_type : !pdl.type)
+    %two = pdl.result 0 of %two_op
+    %i32_type = pdl.type : i32
+    %m_op = pdl.operation "tensor.dim"(%lhs, %one : !pdl.value, !pdl.value)
+    %m = pdl.result 0 of %m_op
+    %n_op = pdl.operation "tensor.dim"(%rhs, %two : !pdl.value, !pdl.value)
+    %n = pdl.result 0 of %n_op
+    %k_op = pdl.operation "tensor.dim"(%lhs, %two : !pdl.value, !pdl.value)
+    %k = pdl.result 0 of %k_op
+    %m_i32_op = pdl.operation "arith.index_cast"(%m : !pdl.value) -> (%i32_type : !pdl.type)
+    %m_i32 = pdl.result 0 of %m_i32_op
+    %n_i32_op = pdl.operation "arith.index_cast"(%n : !pdl.value) -> (%i32_type : !pdl.type)
+    %n_i32 = pdl.result 0 of %n_i32_op
+    %k_i32_op = pdl.operation "arith.index_cast"(%k : !pdl.value) -> (%i32_type : !pdl.type)
+    %k_i32 = pdl.result 0 of %k_i32_op
+
+    // NOTE(review): no dynamic dims are passed (empty range) even though the
+    // guidance below says to assume dynamic shapes; this works for the static
+    // shapes in the accompanying test. Confirm before reusing the pattern.
+    %replaced_values_dims = pdl.range : !pdl.range<value>
+    %input_values = pdl.range %lhs, %rhs : !pdl.value, !pdl.value
+    %replaced_value = pdl.result 0 of %relu
+    %replaced_values = pdl.range %replaced_value : !pdl.value
+    %other_operands = pdl.range %m_i32, %n_i32, %k_i32 : !pdl.value, !pdl.value, !pdl.value
+
+    // The `rewriteAsFlowDispatch` is a rewrite function that allows
+    // converting the matched dag into a call to the external function call
+    // provided by a system plugin. The rewrite method expects the following
+    // arguments
+    // - the root of the matched DAG. This op will be erased after the call.
+    // - `fn_name` the name of the function that is provided externally
+    //   (using a plugin).
+    // - `input_values` are values that are captured as part of the match
+    //   and are inputs to the match.
+    // - `replaced_values` are the values that are captured as part of the
+    //   match and are replaced by the `flow.dispatch`. The `flow.dispatch`
+    //   returns as many values as `replaced_values` (and of same type).
+    // - `replaced_values_dims` are the values for the dynamic dimensions of
+    //   all the `tensor` values in `replaced_values`. For matches that could
+    //   be static or dynamic, it should be assumed that the shape is dynamic
+    //   and the value needs to be passed to the rewrite function.
+    // - `other_operands` same as `input_values`, but kept separate to allow
+    //   flexibility of where the results are passed through the ABI boundary.
+    %fn_name = pdl.attribute = "mlp_external"
+    pdl.apply_native_rewrite "rewriteAsFlowDispatch"(
+        %relu, %fn_name, %input_values, %replaced_values, %replaced_values_dims, %other_operands
+        : !pdl.operation, !pdl.attribute, !pdl.range<value>, !pdl.range<value>, !pdl.range<value>, !pdl.range<value>)
+  }
+}
+
diff --git a/compiler/plugins/input/TOSA/InputConversion/test/tosa_to_linalg_ext.mlir b/compiler/plugins/input/TOSA/InputConversion/test/tosa_to_linalg_ext.mlir
new file mode 100644
index 0000000..e944830
--- /dev/null
+++ b/compiler/plugins/input/TOSA/InputConversion/test/tosa_to_linalg_ext.mlir
@@ -0,0 +1,83 @@
+// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(iree-tosa-to-linalg-ext))" --verify-diagnostics %s | FileCheck %s
+
+// CHECK-LABEL: @scatter_static
+// Static, batch-size-1 scatter. The batch index column is a zero fill that is
+// concatenated in front of the expanded indices. Indices and updates are then
+// collapsed to 2-D for iree_linalg_ext.scatter.
+func.func @scatter_static(%arg0 : tensor<1x4x5xf32>, %arg1 : tensor<1x2xi32>, %arg2 : tensor<1x2x5xf32>) -> tensor<1x4x5xf32> {
+  // CHECK: %[[EXPANDIDX:.+]] = tensor.expand_shape %arg1
+  // CHECK-SAME{literal}: [[0], [1, 2]] : tensor<1x2xi32> into tensor<1x2x1xi32>
+  // CHECK-DAG: %[[EMPTY:.+]] = tensor.empty() : tensor<1x2x1xi32>
+  // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : i32
+  // CHECK-DAG: %[[FILL:.+]] = linalg.fill ins(%[[C0]] : i32) outs(%[[EMPTY]] : tensor<1x2x1xi32>)
+  // CHECK-DAG: %[[CONCAT:.+]] = tosa.concat %[[FILL]], %[[EXPANDIDX]] {axis = 2 : i32}
+  // CHECK: %[[COLLAPSE_IDX:.+]] = tensor.collapse_shape %[[CONCAT]]
+  // CHECK-SAME{literal}: [[0, 1], [2]] : tensor<1x2x2xi32> into tensor<2x2xi32>
+  // CHECK: %[[COLLAPSE_UPD:.+]] = tensor.collapse_shape %arg2
+  // CHECK-SAME{literal}: [[0, 1], [2]] : tensor<1x2x5xf32> into tensor<2x5xf32>
+  // CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter dimension_map = [0, 1] unique_indices(true)
+  // CHECK-SAME: ins(%[[COLLAPSE_UPD]], %[[COLLAPSE_IDX]] : tensor<2x5xf32>, tensor<2x2xi32>)
+  // CHECK-SAME: outs(%arg0 : tensor<1x4x5xf32>)
+  // CHECK: ^bb0(%[[UPD:.+]]: f32, %{{.+}}: f32):
+  // CHECK: iree_linalg_ext.yield %[[UPD]]
+  // CHECK: } -> tensor<1x4x5xf32>
+  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<1x4x5xf32>, tensor<1x2xi32>, tensor<1x2x5xf32>) -> (tensor<1x4x5xf32>)
+
+  // CHECK: return %[[SCATTER]]
+  return %0 : tensor<1x4x5xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @scatter_static_batched
+// Static scatter with batch size 2. Here the batch index column is produced
+// by a linalg.generic that yields `linalg.index 0` (cast to i32), not a zero
+// fill.
+func.func @scatter_static_batched(%arg0 : tensor<2x4x5xf32>, %arg1 : tensor<2x2xi32>, %arg2 : tensor<2x2x5xf32>) -> tensor<2x4x5xf32> {
+  // CHECK: %[[EXPANDIDX:.+]] = tensor.expand_shape %arg1
+  // CHECK-SAME{literal}: [[0], [1, 2]] : tensor<2x2xi32> into tensor<2x2x1xi32>
+  // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<2x2x1xi32>
+  // CHECK: %[[GENERIC:.+]] = linalg.generic
+  // CHECK-SAME: indexing_maps = [#map, #map]
+  // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
+  // CHECK-SAME: ins(%[[EXPANDIDX]] : tensor<2x2x1xi32>)
+  // Use (not re-define) EMPTY so the generic's init is tied to the tensor.empty.
+  // CHECK-SAME: outs(%[[EMPTY]] : tensor<2x2x1xi32>) {
+  // CHECK: %[[IDX:.+]] = linalg.index 0 : index
+  // CHECK: %[[CAST:.+]] = arith.index_cast %[[IDX]] : index to i32
+  // CHECK: linalg.yield %[[CAST]] : i32
+  // CHECK: %[[CONCAT:.+]] = tosa.concat %[[GENERIC]], %[[EXPANDIDX]] {axis = 2 : i32}
+  // CHECK: %[[COLLAPSE_IDX:.+]] = tensor.collapse_shape %[[CONCAT]]
+  // CHECK-SAME{literal}: [[0, 1], [2]] : tensor<2x2x2xi32> into tensor<4x2xi32>
+  // CHECK: %[[COLLAPSE_UPD:.+]] = tensor.collapse_shape %arg2
+  // CHECK-SAME{literal}: [[0, 1], [2]] : tensor<2x2x5xf32> into tensor<4x5xf32>
+  // CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter dimension_map = [0, 1] unique_indices(true)
+  // CHECK-SAME: ins(%[[COLLAPSE_UPD]], %[[COLLAPSE_IDX]] : tensor<4x5xf32>, tensor<4x2xi32>)
+  // CHECK-SAME: outs(%arg0 : tensor<2x4x5xf32>)
+  // CHECK: ^bb0(%[[UPD:.+]]: f32, %{{.+}}: f32):
+  // CHECK: iree_linalg_ext.yield %[[UPD]]
+  // CHECK: } -> tensor<2x4x5xf32>
+  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<2x4x5xf32>, tensor<2x2xi32>, tensor<2x2x5xf32>) -> (tensor<2x4x5xf32>)
+
+  // CHECK: return %[[SCATTER]]
+  return %0 : tensor<2x4x5xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @scatter_dynamic
+// Fully dynamic shapes. The tensor.empty backing the batch index column takes
+// its two leading sizes from tensor.dim on the expanded indices.
+func.func @scatter_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?xi32>, %arg2 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+  // CHECK-DAG: %[[EXPAND:.+]] = tensor.expand_shape %arg1
+  // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+  // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+  // CHECK-DAG: %[[DIM0:.+]] = tensor.dim %[[EXPAND]], %[[C0]] : tensor<?x?x1xi32>
+  // CHECK-DAG: %[[DIM1:.+]] = tensor.dim %[[EXPAND]], %[[C1]] : tensor<?x?x1xi32>
+  // CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM0]], %[[DIM1]]) : tensor<?x?x1xi32>
+  // CHECK: %[[GENERIC:.+]] = linalg.generic
+  // CHECK-SAME: ins(%[[EXPAND]] : tensor<?x?x1xi32>) outs(%[[EMPTY]] : tensor<?x?x1xi32>) {
+  // CHECK: %[[CONCAT:.+]] = tosa.concat %[[GENERIC]], %[[EXPAND]] {axis = 2 : i32}
+  // CHECK: %[[COLLAPSE_IDX:.+]] = tensor.collapse_shape %[[CONCAT]]
+  // CHECK-SAME{literal}: [[0, 1], [2]] : tensor<?x?x2xi32> into tensor<?x2xi32>
+  // CHECK: %[[COLLAPSE_UPD:.+]] = tensor.collapse_shape %arg2
+  // CHECK-SAME{literal}: [[0, 1], [2]] : tensor<?x?x?xf32> into tensor<?x?xf32>
+  // CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter
+  // CHECK-SAME: ins(%[[COLLAPSE_UPD]], %[[COLLAPSE_IDX]] : tensor<?x?xf32>, tensor<?x2xi32>) outs(%arg0 : tensor<?x?x?xf32>)
+  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x?x?xf32>, tensor<?x?xi32>, tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>)
+
+  // CHECK: return %[[SCATTER]]
+  return %0 : tensor<?x?x?xf32>
+}
+
diff --git a/compiler/plugins/input/TOSA/InputConversion/test/verify_compiler_tosa_input_legality.mlir b/compiler/plugins/input/TOSA/InputConversion/test/verify_compiler_tosa_input_legality.mlir
new file mode 100644
index 0000000..bf1564f
--- /dev/null
+++ b/compiler/plugins/input/TOSA/InputConversion/test/verify_compiler_tosa_input_legality.mlir
@@ -0,0 +1,11 @@
+// RUN: iree-opt --split-input-file --iree-tosa-verify-compiler-input-legality --verify-diagnostics %s
+// The pass must reject the tosa.add below and report it as an attached note.
+
+// expected-error@+1 {{one or more illegal operations were found in the compiler input}}
+module {
+  // tosa.add belongs to the TOSA dialect, which the legality pass marks
+  // illegal; it is the only offending op here.
+  func.func @illegal_tosa(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
+    // expected-note@+1 {{failed to legalize operation 'tosa.add' that was explicitly marked illegal}}
+    %0 = tosa.add %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+    return %0 : tensor<4xf32>
+  }
+}