Add Linalg to SPIRV pipeline to IREE.
Add an experimental pass pipeline to convert Linalg Generic Op to
SPIR-V.
PiperOrigin-RevId: 288769570
diff --git a/iree/compiler/Translation/SPIRV/CMakeLists.txt b/iree/compiler/Translation/SPIRV/CMakeLists.txt
index 4c551b7..9e0b7a9 100644
--- a/iree/compiler/Translation/SPIRV/CMakeLists.txt
+++ b/iree/compiler/Translation/SPIRV/CMakeLists.txt
@@ -17,3 +17,4 @@
add_subdirectory(Passes)
add_subdirectory(ReductionCodegen)
add_subdirectory(XLAToSPIRV)
+add_subdirectory(LinalgToSPIRV)
diff --git a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/BUILD b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/BUILD
new file mode 100644
index 0000000..ddc9fd1
--- /dev/null
+++ b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/BUILD
@@ -0,0 +1,43 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+ name = "LinalgToSPIRV",
+ srcs = [
+ "LinalgToSPIRV.cpp",
+ ],
+ deps = [
+ "@llvm-project//llvm:support",
+ "@llvm-project//mlir:AffineOps",
+ "@llvm-project//mlir:AffineToStandardTransforms",
+ "@llvm-project//mlir:GPUToSPIRVTransforms",
+ "@llvm-project//mlir:GPUTransforms",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:Linalg",
+ "@llvm-project//mlir:LoopsToGPUPass",
+ "@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:SPIRVDialect",
+ "@llvm-project//mlir:SPIRVLowering",
+ "@llvm-project//mlir:StandardOps",
+ "@llvm-project//mlir:StandardToSPIRVConversions",
+ "@llvm-project//mlir:Support",
+ "@llvm-project//mlir:Transforms",
+ ],
+ alwayslink = 1,
+)
diff --git a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/CMakeLists.txt b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/CMakeLists.txt
new file mode 100644
index 0000000..a0955e8
--- /dev/null
+++ b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/CMakeLists.txt
@@ -0,0 +1,36 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+iree_cc_library(
+ NAME
+ LinalgToSPIRV
+ HDRS
+ SRCS
+ "LinalgToSPIRV.cpp"
+ DEPS
+ LLVMSupport
+ MLIRAffineOps
+ MLIRAffineToStandard
+ MLIRGPUtoSPIRVTransforms
+ MLIRGPU
+ MLIRIR
+ MLIRLinalg
+ MLIRLoopsToGPU
+ MLIRPass
+ MLIRSPIRV
+ MLIRStandardOps
+ MLIRStandardToSPIRVTransforms
+ MLIRSupport
+ MLIRTransforms
+)
diff --git a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/LinalgToSPIRV.cpp b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/LinalgToSPIRV.cpp
new file mode 100644
index 0000000..75d60be
--- /dev/null
+++ b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/LinalgToSPIRV.cpp
@@ -0,0 +1,104 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//===- LinalgToSPIRV.cpp - Linalg dialect to SPIR-V dialect----------------===//
+//
+// Implementation of conversion from Linalg To SPIRV
+//
+//===----------------------------------------------------------------------===//
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/MemoryBuffer.h"
+#include "mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.h"
+#include "mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h"
+#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h"
+#include "mlir/Dialect/GPU/Passes.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/SPIRV/Passes.h"
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Module.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Pass/PassOptions.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "mlir/Transforms/Passes.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+struct LinalgToSPIRVPassOptions
+ : public PassPipelineOptions<LinalgToSPIRVPassOptions> {
+ ListOption<int64_t> numWorkGroups{
+ *this, "num-workgroups",
+ llvm::cl::desc(
+ "Number of workgroups in the SPIR-V module for x, followed by y, "
+ "followed by z dimension of the dispatch (others will be ignored)"),
+ llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
+ ListOption<int64_t> workGroupSize{
+ *this, "workgroup-size",
+ llvm::cl::desc(
+ "Workgroup Sizes in the SPIR-V module for x, followed by y, followed "
+ "by z dimension of the dispatch (others will be ignored)"),
+ llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
+};
+} // namespace
+
+static void addLinalgToSPIRVPasses(OpPassManager &pm,
+ const LinalgToSPIRVPassOptions &options) {
+ // TODO(ravishankarm): For now only evaluated with 2D tiling. So set the
+ // workgroup size and numworkgroups to size 2
+ SmallVector<int64_t, 2> numWorkGroups, workGroupSize;
+ numWorkGroups.assign(options.numWorkGroups.begin(),
+ options.numWorkGroups.end());
+ numWorkGroups.resize(2, 1);
+ workGroupSize.assign(options.workGroupSize.begin(),
+ options.workGroupSize.end());
+ workGroupSize.resize(2, 1);
+
+ // Linalg to loops.
+ pm.addPass(linalg::createLinalgTilingPass(workGroupSize));
+ pm.addPass(linalg::createConvertLinalgToLoopsPass());
+ pm.addPass(createLowerAffinePass());
+ pm.addPass(createCanonicalizerPass());
+ pm.addPass(createCSEPass());
+
+ // Loops to GPU.
+ pm.addPass(createLoopToGPUPass(numWorkGroups, workGroupSize));
+ pm.addPass(createGpuKernelOutliningPass());
+ pm.addPass(createCanonicalizerPass());
+ pm.addPass(createCSEPass());
+ pm.addPass(createLowerAffinePass());
+
+ // GPU to SPIR-V.
+ pm.addPass(createLegalizeStdOpsForSPIRVLoweringPass());
+ pm.addPass(createCanonicalizerPass());
+ pm.addPass(createCSEPass());
+ pm.addPass(createConvertGPUToSPIRVPass(workGroupSize));
+
+ // SPIR-V passes for lowering attributes.
+ OpPassManager &spirvModulePM = pm.nest<spirv::ModuleOp>();
+ spirvModulePM.addPass(spirv::createLowerABIAttributesPass());
+ spirvModulePM.addPass(createCanonicalizerPass());
+ spirvModulePM.addPass(createCSEPass());
+}
+
+static PassPipelineRegistration<LinalgToSPIRVPassOptions> linalgToSPIRVPipeline(
+ "iree-linalg-to-spirv",
+ "Runs the progressive lowering pipeline from Linalg to SPIR-V",
+ [](OpPassManager &passManager, const LinalgToSPIRVPassOptions &options) {
+ addLinalgToSPIRVPasses(passManager, options);
+ });
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/test/BUILD b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/test/BUILD
new file mode 100644
index 0000000..76a179d
--- /dev/null
+++ b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/test/BUILD
@@ -0,0 +1,17 @@
+# Tests for common transforms.
+
+load("//iree:build_defs.bzl", "iree_glob_lit_tests", "iree_setup_lit_package")
+
+package(
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_setup_lit_package(
+ data = [
+ "//iree/tools:iree-opt",
+ "@llvm-project//mlir:mlir-translate",
+ ],
+)
+
+iree_glob_lit_tests()
diff --git a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/test/single_pw_op.mlir b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/test/single_pw_op.mlir
new file mode 100644
index 0000000..8d15629
--- /dev/null
+++ b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/test/single_pw_op.mlir
@@ -0,0 +1,14 @@
+// RUN: iree-opt -pass-pipeline='iree-linalg-to-spirv{workgroup-size=2,2 num-workgroups=2,2}' %s
+
+#map0 = (d0, d1) -> (d0, d1)
+
+module {
+ func @fmul(%arg0: memref<12x4xf32>, %arg1: memref<12x4xf32>, %arg2: memref<12x4xf32>) {
+ linalg.generic {args_in = 2 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]} %arg0, %arg1, %arg2 {
+ ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors
+ %0 = mulf %arg3, %arg4 : f32
+ linalg.yield %0 : f32
+ }: memref<12x4xf32>, memref<12x4xf32>, memref<12x4xf32>
+ return
+ }
+}
diff --git a/iree/tools/BUILD b/iree/tools/BUILD
index cec8d96..0843f6a 100644
--- a/iree/tools/BUILD
+++ b/iree/tools/BUILD
@@ -63,8 +63,10 @@
"//iree/compiler/Translation/Interpreter/Transforms",
"//iree/compiler/Translation:IREEVM",
"//iree/compiler/Translation/SPIRV/XLAToSPIRV",
+ "//iree/compiler/Translation/SPIRV/LinalgToSPIRV",
"@llvm-project//llvm:support",
"@llvm-project//mlir:AffineDialectRegistration",
+ "@llvm-project//mlir:LinalgDialectRegistration",
"@llvm-project//mlir:MlirOptLib",
"@llvm-project//mlir:StandardDialectRegistration",
"@llvm-project//mlir:SPIRVDialectRegistration",
diff --git a/iree/tools/CMakeLists.txt b/iree/tools/CMakeLists.txt
index 6bb4d24..52e312f 100644
--- a/iree/tools/CMakeLists.txt
+++ b/iree/tools/CMakeLists.txt
@@ -89,6 +89,7 @@
# iree::base::init
# iree::base::status
# iree::compiler::Translation::Interpreter
+# iree::compiler::Translation::SPIRV::LinalgToSPIRV
# iree::compiler::Translation::SPIRV::XLAToSPIRV
# iree::compiler::Translation::IREEVM
# iree::hal::buffer_view_string_util
@@ -109,6 +110,7 @@
DEPS
${_ALWAYSLINK_LIBS}
iree::compiler::Dialect::VM::Target::Bytecode
+ iree::compiler::Translation::SPIRV::LinalgToSPIRV
iree::compiler::Translation::SPIRV::XLAToSPIRV
MLIRTranslateClParser
)