Migrate TOSA input conversion to a compiler plugin. (#15495)
Progress on https://github.com/openxla/iree/issues/15468
Using the plugin API lets us avoid plumbing `#ifdef
IREE_HAVE_TOSA_INPUT` across the source tree. The `IREE_INPUT_TOSA`
CMake define now just chooses whether to include the plugin or not.
Future PRs will migrate other components like the StableHLO input
conversion. I chose to start with the TOSA conversion because it has
fewer files.
## Build system details
This is _mostly_ a code move. The Bazel and CMake configurations were a
little tricky though. Note that some CMake files here were authored
manually, not via bazel_to_cmake.
The Torch input conversion has files located at
`compiler/plugins/input/Torch/torch-iree/**`, with includes rooted on
`torch-iree` (e.g. `#include "torch-iree/InputConversion/Passes.h"`).
However, Bazel wants include paths to match the (unambiguous)
`WORKSPACE`-relative paths (e.g. `#include
"compiler/plugins/input/Torch/torch-iree/InputConversion/Passes.h"`).
Solutions for working around that in Bazel are [documented
here](https://bazel.build/tutorials/cpp-use-cases#add-include-paths).
Similarly, the Torch input conversion roots its CMake targets at
`torch-iree::InputConversion`. In cases where we use relative targets
that isn't an issue, but bazel_to_cmake does not understand that for
absolute target labels.
We have a few options:
1. Use WORKSPACE-relative paths and generate CMake files from Bazel
files
2. (this PR) Use plugin-relative paths and manually author a few CMake
files
3. Use plugin-relative paths and teach bazel_to_cmake about include
copts
4. Use something like `//compiler/src:defs` to propagate include paths
(I tried this a few ways already)
5. Use relative include paths (`./...`)
## Future work
* This new plugin depends on `iree::compiler::Dialect::Flow::Transforms`
for `IREE::Flow::createStripSignednessPass()`. I think this is the only
use of that pass, so it could be moved in to the plugin.
* Mixed input dialect conversion is worth trying now
diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS
index a5ebd13..39dc8a1 100644
--- a/.github/CODEOWNERS
+++ b/.github/CODEOWNERS
@@ -73,8 +73,8 @@
/compiler/src/iree/compiler/Dialect/Vulkan/ @antiagainst
/compiler/src/iree/compiler/GlobalOptimization/ @hanhanW
/compiler/src/iree/compiler/InputConversion/ @MaheshRavishankar @stellaraccident
-/compiler/src/iree/compiler/InputConversion/MHLO @hanhanW @MaheshRavishankar @rsuderman
-/compiler/src/iree/compiler/InputConversion/TOSA @MaheshRavishankar @rsuderman
+/compiler/src/iree/compiler/InputConversion/StableHLO/ @hanhanW @MaheshRavishankar @rsuderman
+/compiler/plugins/input/TOSA/ @MaheshRavishankar @rsuderman
# Runtime
/runtime/src/iree/ @benvanik
diff --git a/compiler/plugins/input/TOSA/BUILD.bazel b/compiler/plugins/input/TOSA/BUILD.bazel
new file mode 100644
index 0000000..522ca5d
--- /dev/null
+++ b/compiler/plugins/input/TOSA/BUILD.bazel
@@ -0,0 +1,11 @@
+# Copyright 2023 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
diff --git a/compiler/plugins/input/TOSA/CMakeLists.txt b/compiler/plugins/input/TOSA/CMakeLists.txt
new file mode 100644
index 0000000..e5c29d4
--- /dev/null
+++ b/compiler/plugins/input/TOSA/CMakeLists.txt
@@ -0,0 +1,22 @@
+# Copyright 2023 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+set(IREE_PACKAGE_ROOT_DIR "${CMAKE_CURRENT_LIST_DIR}")
+set(IREE_PACKAGE_ROOT_PREFIX "")
+set(IREE_COMPILER_TABLEGEN_INCLUDE_DIRS "${CMAKE_CURRENT_SOURCE_DIR}")
+
+add_library(tosa-iree_compiler_defs INTERFACE)
+target_include_directories(tosa-iree_compiler_defs
+ INTERFACE
+ ${CMAKE_CURRENT_SOURCE_DIR}
+ ${CMAKE_CURRENT_BINARY_DIR}
+)
+
+# Configures all iree_cc_* targets to take this implicit dep,
+# which provides common includes and copts for the tree.
+set(IREE_IMPLICIT_DEFS_CC_DEPS tosa-iree_compiler_defs)
+
+add_subdirectory(tosa-iree)
diff --git a/compiler/plugins/input/TOSA/tosa-iree/BUILD.bazel b/compiler/plugins/input/TOSA/tosa-iree/BUILD.bazel
new file mode 100644
index 0000000..1e8d443
--- /dev/null
+++ b/compiler/plugins/input/TOSA/tosa-iree/BUILD.bazel
@@ -0,0 +1,39 @@
+# Copyright 2023 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+load("//build_tools/bazel:build_defs.oss.bzl", "iree_compiler_cc_library", "iree_compiler_register_plugin")
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_compiler_register_plugin(
+ plugin_id = "input_tosa",
+ target = ":registration",
+)
+
+iree_compiler_cc_library(
+ name = "registration",
+ srcs = [
+ "PluginRegistration.cpp",
+ ],
+ copts = ["-Icompiler/plugins/input/TOSA"],
+ deps = [
+ "//compiler/plugins/input/TOSA/tosa-iree/InputConversion",
+ "//compiler/src/iree/compiler/PluginAPI",
+ "@llvm-project//mlir:ConversionPasses",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:TosaDialect",
+ "@llvm-project//mlir:TosaToArith",
+ "@llvm-project//mlir:TosaToLinalg",
+ "@llvm-project//mlir:TosaToSCF",
+ "@llvm-project//mlir:TosaToTensor",
+ "@llvm-project//mlir:Transforms",
+ ],
+)
diff --git a/compiler/plugins/input/TOSA/tosa-iree/CMakeLists.txt b/compiler/plugins/input/TOSA/tosa-iree/CMakeLists.txt
new file mode 100644
index 0000000..bfa618f
--- /dev/null
+++ b/compiler/plugins/input/TOSA/tosa-iree/CMakeLists.txt
@@ -0,0 +1,27 @@
+iree_add_all_subdirs()
+
+iree_compiler_register_plugin(
+ PLUGIN_ID
+ input_tosa
+ TARGET
+ ::registration
+)
+
+iree_cc_library(
+ NAME
+ registration
+ SRCS
+ "PluginRegistration.cpp"
+ DEPS
+ MLIRIR
+ MLIRPass
+ MLIRTosaDialect
+ MLIRTosaToArith
+ MLIRTosaToLinalg
+ MLIRTosaToSCF
+ MLIRTosaToTensor
+ MLIRTransforms
+ iree::compiler::PluginAPI
+ tosa-iree::InputConversion::InputConversion
+ PUBLIC
+)
diff --git a/compiler/src/iree/compiler/InputConversion/TOSA/BUILD.bazel b/compiler/plugins/input/TOSA/tosa-iree/InputConversion/BUILD.bazel
similarity index 89%
rename from compiler/src/iree/compiler/InputConversion/TOSA/BUILD.bazel
rename to compiler/plugins/input/TOSA/tosa-iree/InputConversion/BUILD.bazel
index cdd9ed6..0feea75 100644
--- a/compiler/src/iree/compiler/InputConversion/TOSA/BUILD.bazel
+++ b/compiler/plugins/input/TOSA/tosa-iree/InputConversion/BUILD.bazel
@@ -34,6 +34,10 @@
"Passes.h",
"Passes.h.inc",
],
+ copts = [
+ "-Icompiler/plugins/input/TOSA",
+ "-I$(GENDIR)/compiler/plugins/input/TOSA",
+ ],
deps = [
":PassesIncGen",
"//llvm-external-projects/iree-dialects:IREELinalgExtDialect",
@@ -47,7 +51,7 @@
)
iree_compiler_cc_library(
- name = "TOSA",
+ name = "InputConversion",
srcs = [
"Converti48Toi64.cpp",
"Passes.cpp",
@@ -57,8 +61,9 @@
hdrs = [
"Passes.h",
],
- defines = [
- "IREE_HAVE_TOSA_INPUT",
+ copts = [
+ "-Icompiler/plugins/input/TOSA",
+ "-I$(GENDIR)/compiler/plugins/input/TOSA",
],
deps = [
":PassHeaders",
diff --git a/compiler/src/iree/compiler/InputConversion/TOSA/CMakeLists.txt b/compiler/plugins/input/TOSA/tosa-iree/InputConversion/CMakeLists.txt
similarity index 83%
rename from compiler/src/iree/compiler/InputConversion/TOSA/CMakeLists.txt
rename to compiler/plugins/input/TOSA/tosa-iree/InputConversion/CMakeLists.txt
index e7e8f0e..68e36ce 100644
--- a/compiler/src/iree/compiler/InputConversion/TOSA/CMakeLists.txt
+++ b/compiler/plugins/input/TOSA/tosa-iree/InputConversion/CMakeLists.txt
@@ -1,6 +1,6 @@
################################################################################
# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
-# compiler/src/iree/compiler/InputConversion/TOSA/BUILD.bazel #
+# compiler/plugins/input/TOSA/tosa-iree/InputConversion/BUILD.bazel #
# #
# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
# CMake-only content. #
@@ -22,6 +22,9 @@
iree_cc_library(
NAME
PassHeaders
+ COPTS
+ "-Icompiler/plugins/input/TOSA"
+ "-I$(GENDIR)/compiler/plugins/input/TOSA"
HDRS
"PassDetail.h"
"Passes.h"
@@ -40,7 +43,10 @@
iree_cc_library(
NAME
- TOSA
+ InputConversion
+ COPTS
+ "-Icompiler/plugins/input/TOSA"
+ "-I$(GENDIR)/compiler/plugins/input/TOSA"
HDRS
"Passes.h"
SRCS
@@ -67,14 +73,7 @@
MLIRTransforms
iree::compiler::Dialect::Flow::Transforms
iree::compiler::InputConversion::Common
- DEFINES
- "IREE_HAVE_TOSA_INPUT"
PUBLIC
)
### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
-# TODO: For some reason, these dependencies are not being added automatically.
-add_dependencies(
- iree_compiler_InputConversion_TOSA_PassHeaders
- iree_compiler_InputConversion_TOSA_PassesIncGen
-)
diff --git a/compiler/src/iree/compiler/InputConversion/TOSA/Converti48Toi64.cpp b/compiler/plugins/input/TOSA/tosa-iree/InputConversion/Converti48Toi64.cpp
similarity index 98%
rename from compiler/src/iree/compiler/InputConversion/TOSA/Converti48Toi64.cpp
rename to compiler/plugins/input/TOSA/tosa-iree/InputConversion/Converti48Toi64.cpp
index 3ccc620..756e845 100644
--- a/compiler/src/iree/compiler/InputConversion/TOSA/Converti48Toi64.cpp
+++ b/compiler/plugins/input/TOSA/tosa-iree/InputConversion/Converti48Toi64.cpp
@@ -10,7 +10,8 @@
//
//===----------------------------------------------------------------------===//
-#include "iree/compiler/InputConversion/TOSA/PassDetail.h"
+#include "tosa-iree/InputConversion/PassDetail.h"
+
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
diff --git a/compiler/src/iree/compiler/InputConversion/TOSA/PassDetail.h b/compiler/plugins/input/TOSA/tosa-iree/InputConversion/PassDetail.h
similarity index 73%
rename from compiler/src/iree/compiler/InputConversion/TOSA/PassDetail.h
rename to compiler/plugins/input/TOSA/tosa-iree/InputConversion/PassDetail.h
index 8987600..f763344 100644
--- a/compiler/src/iree/compiler/InputConversion/TOSA/PassDetail.h
+++ b/compiler/plugins/input/TOSA/tosa-iree/InputConversion/PassDetail.h
@@ -4,8 +4,8 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#ifndef IREE_COMPILER_INPUTCONVERSION_TOSA_PASSDETAIL_H_
-#define IREE_COMPILER_INPUTCONVERSION_TOSA_PASSDETAIL_H_
+#ifndef TOSA_IREE_INPUTCONVERSION_PASSDETAIL_H_
+#define TOSA_IREE_INPUTCONVERSION_PASSDETAIL_H_
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
@@ -19,9 +19,9 @@
namespace iree_compiler {
#define GEN_PASS_CLASSES
-#include "iree/compiler/InputConversion/TOSA/Passes.h.inc"
+#include "tosa-iree/InputConversion/Passes.h.inc"
} // namespace iree_compiler
} // namespace mlir
-#endif // IREE_COMPILER_INPUTCONVERSION_TOSA_PASSDETAIL_H_
+#endif // TOSA_IREE_INPUTCONVERSION_PASSDETAIL_H_
diff --git a/compiler/src/iree/compiler/InputConversion/TOSA/Passes.cpp b/compiler/plugins/input/TOSA/tosa-iree/InputConversion/Passes.cpp
similarity index 92%
rename from compiler/src/iree/compiler/InputConversion/TOSA/Passes.cpp
rename to compiler/plugins/input/TOSA/tosa-iree/InputConversion/Passes.cpp
index ae907a7..532cb9d 100644
--- a/compiler/src/iree/compiler/InputConversion/TOSA/Passes.cpp
+++ b/compiler/plugins/input/TOSA/tosa-iree/InputConversion/Passes.cpp
@@ -4,7 +4,7 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#include "iree/compiler/InputConversion/TOSA/Passes.h"
+#include "tosa-iree/InputConversion/Passes.h"
#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
#include "iree/compiler/InputConversion/Common/Passes.h"
@@ -60,6 +60,8 @@
passManager.addNestedPass<func::FuncOp>(tosa::createTosaToArith());
passManager.addNestedPass<func::FuncOp>(tosa::createTosaToTensor());
+ // TODO(scotttodd): move IREE::Flow::createStripSignednessPass into plugin
+ // (should in-tree plugins even depend on other in-tree code?)
passManager.addNestedPass<func::FuncOp>(
IREE::Flow::createStripSignednessPass());
passManager.addNestedPass<func::FuncOp>(mlir::createCanonicalizerPass());
@@ -78,7 +80,7 @@
namespace {
#define GEN_PASS_REGISTRATION
-#include "iree/compiler/InputConversion/TOSA/Passes.h.inc" // IWYU pragma: export
+#include "tosa-iree/InputConversion/Passes.h.inc" // IWYU pragma: export
} // namespace
void registerTOSAConversionPasses() {
diff --git a/compiler/src/iree/compiler/InputConversion/TOSA/Passes.h b/compiler/plugins/input/TOSA/tosa-iree/InputConversion/Passes.h
similarity index 91%
rename from compiler/src/iree/compiler/InputConversion/TOSA/Passes.h
rename to compiler/plugins/input/TOSA/tosa-iree/InputConversion/Passes.h
index 1508e45..4eb693a 100644
--- a/compiler/src/iree/compiler/InputConversion/TOSA/Passes.h
+++ b/compiler/plugins/input/TOSA/tosa-iree/InputConversion/Passes.h
@@ -4,8 +4,8 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#ifndef IREE_COMPILER_INPUTCONVERSION_TOSA_PASSES_H_
-#define IREE_COMPILER_INPUTCONVERSION_TOSA_PASSES_H_
+#ifndef TOSA_IREE_INPUTCONVERSION_PASSES_H_
+#define TOSA_IREE_INPUTCONVERSION_PASSES_H_
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
@@ -50,4 +50,4 @@
} // namespace iree_compiler
} // namespace mlir
-#endif // IREE_COMPILER_INPUTCONVERSION_TOSA_PASSES_H_
+#endif // TOSA_IREE_INPUTCONVERSION_PASSES_H_
diff --git a/compiler/src/iree/compiler/InputConversion/TOSA/Passes.td b/compiler/plugins/input/TOSA/tosa-iree/InputConversion/Passes.td
similarity index 87%
rename from compiler/src/iree/compiler/InputConversion/TOSA/Passes.td
rename to compiler/plugins/input/TOSA/tosa-iree/InputConversion/Passes.td
index 5214342..3a6bbf9 100644
--- a/compiler/src/iree/compiler/InputConversion/TOSA/Passes.td
+++ b/compiler/plugins/input/TOSA/tosa-iree/InputConversion/Passes.td
@@ -4,8 +4,8 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#ifndef IREE_COMPILER_INPUTCONVERSION_TOSA_PASSES
-#define IREE_COMPILER_INPUTCONVERSION_TOSA_PASSES
+#ifndef TOSA_IREE_INPUTCONVERSION_PASSES
+#define TOSA_IREE_INPUTCONVERSION_PASSES
include "mlir/Pass/PassBase.td"
@@ -33,4 +33,4 @@
let constructor = "mlir::iree_compiler::createConverti48Toi64()";
}
-#endif // IREE_COMPILER_INPUTCONVERSION_TOSA_PASSES
+#endif // TOSA_IREE_INPUTCONVERSION_PASSES
diff --git a/compiler/src/iree/compiler/InputConversion/TOSA/TosaToLinalgExt.cpp b/compiler/plugins/input/TOSA/tosa-iree/InputConversion/TosaToLinalgExt.cpp
similarity index 98%
rename from compiler/src/iree/compiler/InputConversion/TOSA/TosaToLinalgExt.cpp
rename to compiler/plugins/input/TOSA/tosa-iree/InputConversion/TosaToLinalgExt.cpp
index 1da5e3d..953aea0 100644
--- a/compiler/src/iree/compiler/InputConversion/TOSA/TosaToLinalgExt.cpp
+++ b/compiler/plugins/input/TOSA/tosa-iree/InputConversion/TosaToLinalgExt.cpp
@@ -12,8 +12,6 @@
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
-#include "iree/compiler/InputConversion/TOSA/PassDetail.h"
-#include "iree/compiler/InputConversion/TOSA/Passes.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -21,6 +19,8 @@
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "tosa-iree/InputConversion/PassDetail.h"
+#include "tosa-iree/InputConversion/Passes.h"
using namespace mlir;
using namespace mlir::tosa;
diff --git a/compiler/src/iree/compiler/InputConversion/TOSA/VerifyCompilerTOSAInputLegality.cpp b/compiler/plugins/input/TOSA/tosa-iree/InputConversion/VerifyCompilerTOSAInputLegality.cpp
similarity index 95%
rename from compiler/src/iree/compiler/InputConversion/TOSA/VerifyCompilerTOSAInputLegality.cpp
rename to compiler/plugins/input/TOSA/tosa-iree/InputConversion/VerifyCompilerTOSAInputLegality.cpp
index 719c015..191a225 100644
--- a/compiler/src/iree/compiler/InputConversion/TOSA/VerifyCompilerTOSAInputLegality.cpp
+++ b/compiler/plugins/input/TOSA/tosa-iree/InputConversion/VerifyCompilerTOSAInputLegality.cpp
@@ -4,12 +4,12 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#include "iree/compiler/InputConversion/TOSA/PassDetail.h"
-#include "iree/compiler/InputConversion/TOSA/Passes.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "tosa-iree/InputConversion/PassDetail.h"
+#include "tosa-iree/InputConversion/Passes.h"
namespace mlir {
namespace iree_compiler {
diff --git a/compiler/src/iree/compiler/InputConversion/TOSA/test/BUILD.bazel b/compiler/plugins/input/TOSA/tosa-iree/InputConversion/test/BUILD.bazel
similarity index 95%
rename from compiler/src/iree/compiler/InputConversion/TOSA/test/BUILD.bazel
rename to compiler/plugins/input/TOSA/tosa-iree/InputConversion/test/BUILD.bazel
index 44ca9fd..234b517 100644
--- a/compiler/src/iree/compiler/InputConversion/TOSA/test/BUILD.bazel
+++ b/compiler/plugins/input/TOSA/tosa-iree/InputConversion/test/BUILD.bazel
@@ -4,8 +4,6 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-# Tests for common transforms.
-
load("//build_tools/bazel:iree_lit_test.bzl", "iree_lit_test_suite")
load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob")
diff --git a/compiler/src/iree/compiler/InputConversion/TOSA/test/CMakeLists.txt b/compiler/plugins/input/TOSA/tosa-iree/InputConversion/test/CMakeLists.txt
similarity index 91%
rename from compiler/src/iree/compiler/InputConversion/TOSA/test/CMakeLists.txt
rename to compiler/plugins/input/TOSA/tosa-iree/InputConversion/test/CMakeLists.txt
index f4b124d..c226bc4 100644
--- a/compiler/src/iree/compiler/InputConversion/TOSA/test/CMakeLists.txt
+++ b/compiler/plugins/input/TOSA/tosa-iree/InputConversion/test/CMakeLists.txt
@@ -1,6 +1,6 @@
################################################################################
# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
-# compiler/src/iree/compiler/InputConversion/TOSA/test/BUILD.bazel #
+# compiler/plugins/input/TOSA/tosa-iree/InputConversion/test/BUILD.bazel #
# #
# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
# CMake-only content. #
diff --git a/compiler/src/iree/compiler/InputConversion/TOSA/test/convert_i48_to_i64.mlir b/compiler/plugins/input/TOSA/tosa-iree/InputConversion/test/convert_i48_to_i64.mlir
similarity index 100%
rename from compiler/src/iree/compiler/InputConversion/TOSA/test/convert_i48_to_i64.mlir
rename to compiler/plugins/input/TOSA/tosa-iree/InputConversion/test/convert_i48_to_i64.mlir
diff --git a/compiler/src/iree/compiler/InputConversion/TOSA/test/tosa_to_linalg_ext.mlir b/compiler/plugins/input/TOSA/tosa-iree/InputConversion/test/tosa_to_linalg_ext.mlir
similarity index 100%
rename from compiler/src/iree/compiler/InputConversion/TOSA/test/tosa_to_linalg_ext.mlir
rename to compiler/plugins/input/TOSA/tosa-iree/InputConversion/test/tosa_to_linalg_ext.mlir
diff --git a/compiler/src/iree/compiler/InputConversion/TOSA/test/verify_compiler_tosa_input_legality.mlir b/compiler/plugins/input/TOSA/tosa-iree/InputConversion/test/verify_compiler_tosa_input_legality.mlir
similarity index 100%
rename from compiler/src/iree/compiler/InputConversion/TOSA/test/verify_compiler_tosa_input_legality.mlir
rename to compiler/plugins/input/TOSA/tosa-iree/InputConversion/test/verify_compiler_tosa_input_legality.mlir
diff --git a/compiler/plugins/input/TOSA/tosa-iree/PluginRegistration.cpp b/compiler/plugins/input/TOSA/tosa-iree/PluginRegistration.cpp
new file mode 100644
index 0000000..5235ba9
--- /dev/null
+++ b/compiler/plugins/input/TOSA/tosa-iree/PluginRegistration.cpp
@@ -0,0 +1,80 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/PluginAPI/Client.h"
+#include "mlir/Conversion/Passes.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Pass/PassManager.h"
+
+#include "tosa-iree/InputConversion/Passes.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+
+// TOSA (Tensor Operator Set Architecture) support plugin.
+// * https://www.mlplatform.org/tosa
+// * https://mlir.llvm.org/docs/Dialects/TOSA/
+//
+// The TOSA plugin provides dialects, passes and opt-in options.
+// Therefore, it is appropriate for default activation.
+struct TOSASession
+ : public PluginSession<TOSASession, EmptyPluginOptions,
+ PluginActivationPolicy::DefaultActivated> {
+ static void registerPasses() {
+ registerTOSAConversionPasses();
+ registerTosaToArith();
+ registerTosaToLinalg();
+ registerTosaToTensor();
+ }
+
+ void onRegisterDialects(DialectRegistry ®istry) override {
+ registry.insert<tosa::TosaDialect>();
+ }
+
+ bool extendCustomInputConversionPassPipeline(
+ OpPassManager &passManager, std::string_view typeMnemonic) override {
+ if (typeMnemonic == "tosa") {
+ buildTOSAInputConversionPassPipeline(passManager);
+ return true;
+ }
+
+ return false;
+ }
+
+ void populateCustomInputConversionTypes(StringSet<> &typeMnemonics) override {
+ typeMnemonics.insert("tosa");
+ }
+
+ void populateDetectedCustomInputConversionTypes(
+ ModuleOp &module, StringSet<> &typeMnemonics) override {
+ auto *ctx = module.getContext();
+ const Dialect *tosaDialect = ctx->getLoadedDialect("tosa");
+
+ module.walk([&](Operation *op) {
+ Dialect *d = op->getDialect();
+ if (d == tosaDialect) {
+ typeMnemonics.insert("tosa");
+ return WalkResult::interrupt();
+ }
+ return WalkResult::advance();
+ });
+ }
+};
+
+} // namespace
+
+} // namespace iree_compiler
+} // namespace mlir
+
+extern "C" bool iree_register_compiler_plugin_input_tosa(
+ mlir::iree_compiler::PluginRegistrar *registrar) {
+ registrar->registerPlugin<::mlir::iree_compiler::TOSASession>("input_tosa");
+ return true;
+}
diff --git a/compiler/plugins/iree_compiler_plugin_group.cmake b/compiler/plugins/iree_compiler_plugin_group.cmake
index 1b64c11..df17ecd 100644
--- a/compiler/plugins/iree_compiler_plugin_group.cmake
+++ b/compiler/plugins/iree_compiler_plugin_group.cmake
@@ -8,6 +8,10 @@
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/input/Torch input/Torch)
endif()
+if(IREE_INPUT_TOSA)
+ add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/input/TOSA input/TOSA)
+endif()
+
if(IREE_TARGET_BACKEND_CUDA)
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/target/CUDA target/CUDA)
endif()
diff --git a/compiler/src/iree/compiler/InputConversion/CMakeLists.txt b/compiler/src/iree/compiler/InputConversion/CMakeLists.txt
index f28c068..9467020 100644
--- a/compiler/src/iree/compiler/InputConversion/CMakeLists.txt
+++ b/compiler/src/iree/compiler/InputConversion/CMakeLists.txt
@@ -9,6 +9,3 @@
if(IREE_INPUT_STABLEHLO)
add_subdirectory(StableHLO)
endif()
-if(IREE_INPUT_TOSA)
- add_subdirectory(TOSA)
-endif()
diff --git a/compiler/src/iree/compiler/InputConversion/Common/AutoInputConversionPipeline.cpp b/compiler/src/iree/compiler/InputConversion/Common/AutoInputConversionPipeline.cpp
index 4f751c7..c6b8948 100644
--- a/compiler/src/iree/compiler/InputConversion/Common/AutoInputConversionPipeline.cpp
+++ b/compiler/src/iree/compiler/InputConversion/Common/AutoInputConversionPipeline.cpp
@@ -8,7 +8,6 @@
#include "iree/compiler/InputConversion/Common/Passes.h"
#include "iree/compiler/PluginAPI/Client.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
@@ -19,9 +18,6 @@
#include "iree/compiler/InputConversion/StableHLO/Passes.h"
#include "stablehlo/dialect/StablehloOps.h"
#endif // IREE_HAVE_STABLEHLO_INPUT
-#ifdef IREE_HAVE_TOSA_INPUT
-#include "iree/compiler/InputConversion/TOSA/Passes.h"
-#endif // IREE_HAVE_TOSA_INPUT
namespace mlir::iree_compiler {
namespace {
@@ -47,8 +43,6 @@
bool hasStableHLO = false;
// - XLA import features.
bool hasTuples = false;
- // TOSA features.
- bool hasTOSA = false;
};
static void populateHloFeatures(Operation *op, InputFeatures &features) {
@@ -89,17 +83,12 @@
static void populateFeatures(Operation *op, const Dialect *chloDialect,
const Dialect *stablehloDialect,
- const Dialect *tosaDialect,
InputFeatures &features) {
Dialect *d = op->getDialect();
if (d == stablehloDialect || d == chloDialect) {
features.hasStableHLO = true;
return populateHloFeatures(op, features);
}
- if (d == tosaDialect) {
- features.hasTOSA = true;
- return;
- }
}
void AutoInputConversionPipelinePass::runOnOperation() {
@@ -152,23 +141,15 @@
InputFeatures features;
const Dialect *chloDialect = context->getLoadedDialect("chlo");
const Dialect *stablehloDialect = context->getLoadedDialect("stablehlo");
- const Dialect *tosaDialect = context->getLoadedDialect("tosa");
- if (!chloDialect && !stablehloDialect && !tosaDialect) {
+ if (!chloDialect && !stablehloDialect) {
return;
}
- auto res = module.walk([&](Operation *op) {
- populateFeatures(op, chloDialect, stablehloDialect, tosaDialect, features);
- if (features.hasStableHLO && features.hasTOSA) {
- module.emitError("not yet implemented mixture of *HLO and TOSA");
- return WalkResult::interrupt();
- }
+ module.walk([&](Operation *op) {
+ populateFeatures(op, chloDialect, stablehloDialect, features);
return WalkResult::advance();
});
- if (res.wasInterrupted()) {
- return signalPassFailure();
- }
- if (!features.hasStableHLO && !features.hasTOSA) {
+ if (!features.hasStableHLO) {
return;
}
@@ -187,11 +168,6 @@
}
}
#endif // IREE_HAVE_STABLEHLO_INPUT
-#ifdef IREE_HAVE_TOSA_INPUT
- if (features.hasTOSA) {
- buildTOSAInputConversionPassPipeline(pm);
- }
-#endif // IREE_HAVE_TOSA_INPUT
if (failed(runPipeline(pm, module))) {
signalPassFailure();
@@ -230,9 +206,9 @@
stablehlo::buildStableHLOXLAInputConversionPassPipeline);
#endif // IREE_HAVE_STABLEHLO_INPUT
-#ifdef IREE_HAVE_TOSA_INPUT
- appendPipelineDialects(buildTOSAInputConversionPassPipeline);
-#endif // IREE_HAVE_TOSA_INPUT
+ if (pipelineExtensions) {
+ pipelineExtensions->registerDialects(registry);
+ }
if (pipelineExtensions) {
pipelineExtensions->registerDialects(registry);
diff --git a/compiler/src/iree/compiler/InputConversion/Common/BUILD.bazel b/compiler/src/iree/compiler/InputConversion/Common/BUILD.bazel
index 0abea9a..af5d141 100644
--- a/compiler/src/iree/compiler/InputConversion/Common/BUILD.bazel
+++ b/compiler/src/iree/compiler/InputConversion/Common/BUILD.bazel
@@ -95,14 +95,12 @@
":PassHeaders",
":PassesIncGen",
"//compiler/src/iree/compiler/InputConversion/StableHLO",
- "//compiler/src/iree/compiler/InputConversion/TOSA",
"//compiler/src/iree/compiler/PluginAPI",
"//compiler/src/iree/compiler/Utils",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
- "@llvm-project//mlir:TosaDialect",
"@llvm-project//mlir:Transforms",
"@stablehlo//:stablehlo_ops",
],
diff --git a/compiler/src/iree/compiler/InputConversion/Common/CMakeLists.txt b/compiler/src/iree/compiler/InputConversion/Common/CMakeLists.txt
index 628089f..f0d9908 100644
--- a/compiler/src/iree/compiler/InputConversion/Common/CMakeLists.txt
+++ b/compiler/src/iree/compiler/InputConversion/Common/CMakeLists.txt
@@ -9,9 +9,6 @@
if(IREE_INPUT_STABLEHLO)
list(APPEND IREE_INPUT_DEPS iree::compiler::InputConversion::StableHLO)
endif()
-if(IREE_INPUT_TOSA)
- list(APPEND IREE_INPUT_DEPS iree::compiler::InputConversion::TOSA)
-endif()
iree_add_all_subdirs()
diff --git a/compiler/src/iree/compiler/Pipelines/BUILD.bazel b/compiler/src/iree/compiler/Pipelines/BUILD.bazel
index e022e6d..df24320 100644
--- a/compiler/src/iree/compiler/Pipelines/BUILD.bazel
+++ b/compiler/src/iree/compiler/Pipelines/BUILD.bazel
@@ -18,7 +18,6 @@
hdrs = ["Options.h"],
deps = [
"//compiler/src/iree/compiler/InputConversion/StableHLO",
- "//compiler/src/iree/compiler/InputConversion/TOSA",
"//compiler/src/iree/compiler/Utils",
],
)
@@ -49,7 +48,6 @@
"//compiler/src/iree/compiler/InputConversion/Common",
"//compiler/src/iree/compiler/InputConversion/Common:AutoInputConversionPipeline",
"//compiler/src/iree/compiler/InputConversion/StableHLO",
- "//compiler/src/iree/compiler/InputConversion/TOSA",
"//compiler/src/iree/compiler/Modules/HAL/Inline/Transforms",
"//compiler/src/iree/compiler/Modules/HAL/Loader/Transforms",
"//compiler/src/iree/compiler/Preprocessing:Passes",
diff --git a/compiler/src/iree/compiler/Pipelines/CMakeLists.txt b/compiler/src/iree/compiler/Pipelines/CMakeLists.txt
index 8b838b8..010308b 100644
--- a/compiler/src/iree/compiler/Pipelines/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Pipelines/CMakeLists.txt
@@ -10,9 +10,6 @@
if(IREE_INPUT_STABLEHLO)
list(APPEND IREE_INPUT_DEPS iree::compiler::InputConversion::StableHLO)
endif()
-if(IREE_INPUT_TOSA)
- list(APPEND IREE_INPUT_DEPS iree::compiler::InputConversion::TOSA)
-endif()
iree_cc_library(
NAME
diff --git a/compiler/src/iree/compiler/Pipelines/Options.cpp b/compiler/src/iree/compiler/Pipelines/Options.cpp
index 5d4f8dc..01cdc48 100644
--- a/compiler/src/iree/compiler/Pipelines/Options.cpp
+++ b/compiler/src/iree/compiler/Pipelines/Options.cpp
@@ -45,12 +45,12 @@
" =stablehlo - Legalize from StableHLO ops.\n"
" =stablehlo_xla - Legalize from StableHLO ops (with XLA cleanup preprocessing).\n"
#endif // IREE_HAVE_STABLEHLO_INPUT
-#ifdef IREE_HAVE_TOSA_INPUT
- " =tosa - Legalize from TOSA ops.\n"
-#endif // IREE_HAVE_TOSA_INPUT
// NOTE: The plugin system does not have a good way to populate CL help
-// messages, so we err on the side of being helpful and populating Torch
+// messages, so we err on the side of being helpful and populating plugin
// options here, even though it is a layering violation.
+#ifdef IREE_COMPILER_PLUGIN_HAVE_STATIC_INPUT_TOSA
+ " =tosa - Legalize from TOSA ops.\n"
+#endif // IREE_COMPILER_PLUGIN_HAVE_STATIC_INPUT_TOSA
#ifdef IREE_COMPILER_PLUGIN_HAVE_STATIC_TORCH_IREE
" =tm_tensor - Legalize a subset of Torch input ops.\n"
" =torch - Legalize from the 'torch' dialect.\n"
@@ -89,10 +89,6 @@
} else if (inputTypeMnemonic == "stablehlo_xla") {
return Type::stablehlo_xla;
#endif
-#ifdef IREE_HAVE_TOSA_INPUT
- } else if (inputTypeMnemonic == "tosa") {
- return Type::tosa;
-#endif
} else {
return Type::plugin;
}
diff --git a/compiler/src/iree/compiler/Pipelines/Options.h b/compiler/src/iree/compiler/Pipelines/Options.h
index 0909ab8..e58cbff 100644
--- a/compiler/src/iree/compiler/Pipelines/Options.h
+++ b/compiler/src/iree/compiler/Pipelines/Options.h
@@ -47,10 +47,6 @@
// preprocessing, e.g., flattening of tuples.
stablehlo_xla,
#endif // IREE_HAVE_STABLEHLO_INPUT
-#ifdef IREE_HAVE_TOSA_INPUT
- // Legalizes input defined over TOSA ops.
- tosa,
-#endif // IREE_HAVE_TOSA_INPUT
};
// The flag value is captured into spec by the CL system and it must be
// interpreted by parseInputTypeSpec.
diff --git a/compiler/src/iree/compiler/Pipelines/Pipelines.cpp b/compiler/src/iree/compiler/Pipelines/Pipelines.cpp
index 1f0c0a5..6d2ebcb 100644
--- a/compiler/src/iree/compiler/Pipelines/Pipelines.cpp
+++ b/compiler/src/iree/compiler/Pipelines/Pipelines.cpp
@@ -23,9 +23,6 @@
#ifdef IREE_HAVE_STABLEHLO_INPUT
#include "iree/compiler/InputConversion/StableHLO/Passes.h"
#endif // IREE_HAVE_STABLEHLO_INPUT
-#ifdef IREE_HAVE_TOSA_INPUT
-#include "iree/compiler/InputConversion/TOSA/Passes.h"
-#endif // IREE_HAVE_TOSA_INPUT
namespace mlir {
namespace iree_compiler {
@@ -106,11 +103,6 @@
stablehloOptions);
break;
#endif // IREE_HAVE_STABLEHLO_INPUT
-#ifdef IREE_HAVE_TOSA_INPUT
- case InputDialectOptions::Type::tosa:
- buildTOSAInputConversionPassPipeline(passManager);
- break;
-#endif // IREE_HAVE_TOSA_INPUT
}
buildCommonInputConversionPassPipeline(passManager);
IREE_TRACE_ADD_END_FRAME_PASS(passManager, "Input");
diff --git a/compiler/src/iree/compiler/PluginAPI/Config/BUILD.bazel b/compiler/src/iree/compiler/PluginAPI/Config/BUILD.bazel
index dab783d..16f9943 100644
--- a/compiler/src/iree/compiler/PluginAPI/Config/BUILD.bazel
+++ b/compiler/src/iree/compiler/PluginAPI/Config/BUILD.bazel
@@ -21,6 +21,8 @@
cmd = (
"echo '" +
"HANDLE_PLUGIN_ID(hal_target_cuda)\n" +
+ "HANDLE_PLUGIN_ID(input_tosa)\n" +
+ # Samples
"HANDLE_PLUGIN_ID(example)\n" +
"HANDLE_PLUGIN_ID(simple_io_sample)\n" +
"' > $@"
@@ -39,6 +41,7 @@
# generates its deps from the environment.
# For now, we just hard include all in-tree plugins.
"//compiler/plugins/target/CUDA",
+ "//compiler/plugins/input/TOSA/tosa-iree:registration",
"//samples/compiler_plugins/example:registration",
"//samples/compiler_plugins/simple_io_sample:registration",
],
diff --git a/compiler/src/iree/compiler/Tools/BUILD.bazel b/compiler/src/iree/compiler/Tools/BUILD.bazel
index 5d9ec18..a79bea0 100644
--- a/compiler/src/iree/compiler/Tools/BUILD.bazel
+++ b/compiler/src/iree/compiler/Tools/BUILD.bazel
@@ -38,10 +38,8 @@
deps = [
"//compiler/src/iree/compiler/InputConversion/Common",
"//compiler/src/iree/compiler/InputConversion/StableHLO",
- "//compiler/src/iree/compiler/InputConversion/TOSA",
"@llvm-project//mlir:ConversionPasses",
"@llvm-project//mlir:IR",
- "@llvm-project//mlir:TosaDialect",
"@stablehlo//:chlo_ops",
"@stablehlo//:stablehlo_ops",
],
diff --git a/compiler/src/iree/compiler/Tools/CMakeLists.txt b/compiler/src/iree/compiler/Tools/CMakeLists.txt
index 6553405..61a1609 100644
--- a/compiler/src/iree/compiler/Tools/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Tools/CMakeLists.txt
@@ -47,11 +47,6 @@
list(APPEND IREE_INPUT_DEPS ChloOps)
list(APPEND IREE_INPUT_DEPS StablehloOps)
endif()
-if(IREE_INPUT_TOSA)
- list(APPEND IREE_INPUT_DEPS iree::compiler::InputConversion::TOSA)
- list(APPEND IREE_INPUT_DEPS MLIRTosaDialect)
- list(APPEND IREE_INPUT_DEPS MLIRTosaTransforms)
-endif()
iree_cc_library(
NAME
@@ -226,4 +221,3 @@
COPTS
${IREE_VERSION_TARGET_COPTS}
)
-
diff --git a/compiler/src/iree/compiler/Tools/init_input_dialects.cc b/compiler/src/iree/compiler/Tools/init_input_dialects.cc
index 651e287..2ab2f7c 100644
--- a/compiler/src/iree/compiler/Tools/init_input_dialects.cc
+++ b/compiler/src/iree/compiler/Tools/init_input_dialects.cc
@@ -10,9 +10,6 @@
#include "stablehlo/dialect/ChloOps.h"
#include "stablehlo/dialect/StablehloOps.h"
#endif // IREE_HAVE_STABLEHLO_INPUT
-#ifdef IREE_HAVE_TOSA_INPUT
-#include "mlir/Dialect/Tosa/IR/TosaOps.h"
-#endif // IREE_HAVE_TOSA_INPUT
namespace mlir {
namespace iree_compiler {
@@ -21,9 +18,6 @@
#ifdef IREE_HAVE_STABLEHLO_INPUT
registry.insert<mlir::chlo::ChloDialect, mlir::stablehlo::StablehloDialect>();
#endif // IREE_HAVE_STABLEHLO_INPUT
-#ifdef IREE_HAVE_TOSA_INPUT
- registry.insert<tosa::TosaDialect>();
-#endif // IREE_HAVE_TOSA_INPUT
}
} // namespace iree_compiler
diff --git a/compiler/src/iree/compiler/Tools/init_input_passes.cc b/compiler/src/iree/compiler/Tools/init_input_passes.cc
index 8733758..1259d2c 100644
--- a/compiler/src/iree/compiler/Tools/init_input_passes.cc
+++ b/compiler/src/iree/compiler/Tools/init_input_passes.cc
@@ -11,11 +11,6 @@
#ifdef IREE_HAVE_STABLEHLO_INPUT
#include "iree/compiler/InputConversion/StableHLO/Passes.h"
#endif // IREE_HAVE_STABLEHLO_INPUT
-#ifdef IREE_HAVE_TOSA_INPUT
-#include "iree/compiler/InputConversion/TOSA/Passes.h"
-#include "mlir/Conversion/Passes.h"
-#include "mlir/Dialect/Tosa/Transforms/Passes.h"
-#endif // IREE_HAVE_TOSA_INPUT
namespace mlir {
namespace iree_compiler {
@@ -26,12 +21,6 @@
#ifdef IREE_HAVE_STABLEHLO_INPUT
stablehlo::registerStableHLOConversionPasses();
#endif // IREE_HAVE_STABLEHLO_INPUT
-#ifdef IREE_HAVE_TOSA_INPUT
- registerTOSAConversionPasses();
- registerTosaToArithPass();
- registerTosaToLinalgPass();
- registerTosaToTensorPass();
-#endif // IREE_HAVE_TOSA_INPUT
}
} // namespace iree_compiler