[CUDA-Codegen] Add support for constant tensor (#5966)
Add the upstream TensorConstantBufferize pass and the FoldTensorExtractOp pass to the CUDA codegen pipeline to support lowering inline constant tensors. FoldTensorExtractOpPass and its TableGen patterns move from LinalgToLLVM to Common so the pass can be shared across backends.
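For context, a rough sketch of the intended lowering (op names, symbol names, and IR shapes below are illustrative approximations, not the exact pipeline output): an inline constant tensor is bufferized into a global memref by the upstream TensorConstantBufferize pass, and FoldTensorExtractOp then rewrites the resulting tensor_load/tensor.extract pair into a direct memref.load.

    // Before (inside a dispatch function):
    %cst = constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf32>
    %v = tensor.extract %cst[%i] : tensor<4xf32>

    // After TensorConstantBufferize + FoldTensorExtractOp (approximate):
    memref.global "private" constant @__constant_4xf32 : memref<4xf32> = dense<[1.0, 2.0, 3.0, 4.0]>  // at module scope
    %m = memref.get_global @__constant_4xf32 : memref<4xf32>
    %v = memref.load %m[%i] : memref<4xf32>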
diff --git a/iree/compiler/Conversion/Common/BUILD b/iree/compiler/Conversion/Common/BUILD
index bf7192f..3bd8795 100644
--- a/iree/compiler/Conversion/Common/BUILD
+++ b/iree/compiler/Conversion/Common/BUILD
@@ -12,18 +12,38 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+load("//build_tools/bazel:tblgen.bzl", "gentbl_cc_library")
+
package(
default_visibility = ["//visibility:public"],
features = ["layering_check"],
licenses = ["notice"], # Apache 2.0
)
+gentbl_cc_library(
+ name = "FoldTensorExtractOpIncGen",
+ tbl_outs = [
+ (
+ ["-gen-rewriters"],
+ "FoldTensorExtractOp.cpp.inc",
+ ),
+ ],
+ tblgen = "@llvm-project//mlir:mlir-tblgen",
+ td_file = "FoldTensorExtractOp.td",
+ td_srcs = [
+ "@llvm-project//mlir:OpBaseTdFiles",
+ "@llvm-project//mlir:MemRefOpsTdFiles",
+ "@llvm-project//mlir:TensorOpsTdFiles",
+ ],
+)
+
cc_library(
name = "Common",
srcs = [
"BufferAllocViewCleanUpPass.cpp",
"DemoteF32ToF16.cpp",
"FlattenMemRefSubspanPass.cpp",
+ "FoldTensorExtractOpPass.cpp",
"ForOpCanonicalizationPass.cpp",
"LaunchConfig.cpp",
"LinalgBufferizePass.cpp",
@@ -39,6 +59,7 @@
],
deps = [
"//iree/compiler/Conversion/CodegenUtils",
+ "//iree/compiler/Conversion/Common:FoldTensorExtractOpIncGen",
"//iree/compiler/Dialect/Flow/IR",
"//iree/compiler/Dialect/HAL/IR",
"//iree/compiler/Dialect/IREE/IR",
@@ -58,6 +79,7 @@
"@llvm-project//mlir:SideEffectInterfaces",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
+ "@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:Transforms",
"@llvm-project//mlir:VectorOps",
],
diff --git a/iree/compiler/Conversion/Common/CMakeLists.txt b/iree/compiler/Conversion/Common/CMakeLists.txt
index 9514518..7816f9c 100644
--- a/iree/compiler/Conversion/Common/CMakeLists.txt
+++ b/iree/compiler/Conversion/Common/CMakeLists.txt
@@ -10,6 +10,15 @@
iree_add_all_subdirs()
+iree_tablegen_library(
+ NAME
+ FoldTensorExtractOpIncGen
+ TD_FILE
+ "FoldTensorExtractOp.td"
+ OUTS
+ -gen-rewriters FoldTensorExtractOp.cpp.inc
+)
+
iree_cc_library(
NAME
Common
@@ -21,6 +30,7 @@
"BufferAllocViewCleanUpPass.cpp"
"DemoteF32ToF16.cpp"
"FlattenMemRefSubspanPass.cpp"
+ "FoldTensorExtractOpPass.cpp"
"ForOpCanonicalizationPass.cpp"
"LaunchConfig.cpp"
"LinalgBufferizePass.cpp"
@@ -43,9 +53,11 @@
MLIRSideEffectInterfaces
MLIRStandard
MLIRSupport
+ MLIRTensor
MLIRTransforms
MLIRVector
iree::compiler::Conversion::CodegenUtils
+ iree::compiler::Conversion::Common::FoldTensorExtractOpIncGen
iree::compiler::Dialect::Flow::IR
iree::compiler::Dialect::HAL::IR
iree::compiler::Dialect::IREE::IR
diff --git a/iree/compiler/Conversion/LinalgToLLVM/FoldTensorExtractOp.td b/iree/compiler/Conversion/Common/FoldTensorExtractOp.td
similarity index 81%
rename from iree/compiler/Conversion/LinalgToLLVM/FoldTensorExtractOp.td
rename to iree/compiler/Conversion/Common/FoldTensorExtractOp.td
index 1f6d0b0..7312267 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/FoldTensorExtractOp.td
+++ b/iree/compiler/Conversion/Common/FoldTensorExtractOp.td
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#ifndef IREE_COMPILER_CONVERSION_LINALGTOLLVM_FOLDTENSOREXTRACTOP
-#define IREE_COMPILER_CONVERSION_LINALGTOLLVM_FOLDTENSOREXTRACTOP
+#ifndef IREE_COMPILER_CONVERSION_COMMON_FOLDTENSOREXTRACTOP
+#define IREE_COMPILER_CONVERSION_COMMON_FOLDTENSOREXTRACTOP
include "mlir/Dialect/MemRef/IR/MemRefOps.td"
include "mlir/Dialect/Tensor/IR/TensorOps.td"
@@ -23,4 +23,4 @@
def : Pat<(Tensor_ExtractOp (TensorLoadOp $value), $indices),
(LoadOp $value, $indices)>;
-#endif // IREE_COMPILER_CONVERSION_LINALGTOLLVM_FOLDTENSOREXTRACTOP
+#endif // IREE_COMPILER_CONVERSION_COMMON_FOLDTENSOREXTRACTOP
diff --git a/iree/compiler/Conversion/LinalgToLLVM/FoldTensorExtractOpPass.cpp b/iree/compiler/Conversion/Common/FoldTensorExtractOpPass.cpp
similarity index 95%
rename from iree/compiler/Conversion/LinalgToLLVM/FoldTensorExtractOpPass.cpp
rename to iree/compiler/Conversion/Common/FoldTensorExtractOpPass.cpp
index 026dd95..133ef12 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/FoldTensorExtractOpPass.cpp
+++ b/iree/compiler/Conversion/Common/FoldTensorExtractOpPass.cpp
@@ -11,7 +11,7 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "iree/compiler/Conversion/LinalgToLLVM/Passes.h"
+#include "iree/compiler/Conversion/Common/Passes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -23,7 +23,7 @@
namespace iree_compiler {
namespace {
-#include "iree/compiler/Conversion/LinalgToLLVM/FoldTensorExtractOp.cpp.inc"
+#include "iree/compiler/Conversion/Common/FoldTensorExtractOp.cpp.inc"
}
namespace {
diff --git a/iree/compiler/Conversion/Common/Passes.h b/iree/compiler/Conversion/Common/Passes.h
index 9cd0624..bd29a23 100644
--- a/iree/compiler/Conversion/Common/Passes.h
+++ b/iree/compiler/Conversion/Common/Passes.h
@@ -71,5 +71,10 @@
std::unique_ptr<OperationPass<IREE::HAL::ExecutableTargetOp>>
createSetNumWorkgroupsPass(ArrayRef<int64_t> workgroupSize = {});
+/// After running the upstream TensorConstantBufferize pass, remove tensor_loads
+/// introduced for use only in tensor_extract. These can be folded to use a load
+/// of the created memref object that holds the constant values.
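+///
+/// A sketch of the fold (IR is approximate):
+///   %t = memref.tensor_load %m : memref<16xf32>
+///   %e = tensor.extract %t[%i] : tensor<16xf32>
+/// becomes
+///   %e = memref.load %m[%i] : memref<16xf32>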
+std::unique_ptr<OperationPass<>> createFoldTensorExtractOpPass();
+
} // namespace iree_compiler
} // namespace mlir
diff --git a/iree/compiler/Conversion/Common/test/BUILD b/iree/compiler/Conversion/Common/test/BUILD
index 0d8f1a2..f1b956a 100644
--- a/iree/compiler/Conversion/Common/test/BUILD
+++ b/iree/compiler/Conversion/Common/test/BUILD
@@ -31,6 +31,7 @@
"canonicalize_interface_load_store.mlir",
"f32Tof16.mlir",
"flatten_memref_subspan.mlir",
+ "fold_tensor_extract_op.mlir",
"forop_canonicalization.mlir",
"linalg_bufferize.mlir",
"remove_dead_allocs.mlir",
diff --git a/iree/compiler/Conversion/Common/test/CMakeLists.txt b/iree/compiler/Conversion/Common/test/CMakeLists.txt
index 3ae07b0..25b8b05 100644
--- a/iree/compiler/Conversion/Common/test/CMakeLists.txt
+++ b/iree/compiler/Conversion/Common/test/CMakeLists.txt
@@ -18,6 +18,7 @@
"canonicalize_interface_load_store.mlir"
"f32Tof16.mlir"
"flatten_memref_subspan.mlir"
+ "fold_tensor_extract_op.mlir"
"forop_canonicalization.mlir"
"linalg_bufferize.mlir"
"remove_dead_allocs.mlir"
diff --git a/iree/compiler/Conversion/LinalgToLLVM/test/fold_tensor_extract_op.mlir b/iree/compiler/Conversion/Common/test/fold_tensor_extract_op.mlir
similarity index 100%
rename from iree/compiler/Conversion/LinalgToLLVM/test/fold_tensor_extract_op.mlir
rename to iree/compiler/Conversion/Common/test/fold_tensor_extract_op.mlir
diff --git a/iree/compiler/Conversion/LinalgToLLVM/BUILD b/iree/compiler/Conversion/LinalgToLLVM/BUILD
index b708cc8..567c85e 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/BUILD
+++ b/iree/compiler/Conversion/LinalgToLLVM/BUILD
@@ -12,36 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-load("//build_tools/bazel:tblgen.bzl", "gentbl_cc_library")
-
package(
default_visibility = ["//visibility:public"],
features = ["layering_check"],
licenses = ["notice"], # Apache 2.0
)
-gentbl_cc_library(
- name = "FoldTensorExtractOpIncGen",
- tbl_outs = [
- (
- ["-gen-rewriters"],
- "FoldTensorExtractOp.cpp.inc",
- ),
- ],
- tblgen = "@llvm-project//mlir:mlir-tblgen",
- td_file = "FoldTensorExtractOp.td",
- td_srcs = [
- "@llvm-project//mlir:OpBaseTdFiles",
- "@llvm-project//mlir:MemRefOpsTdFiles",
- "@llvm-project//mlir:TensorOpsTdFiles",
- ],
-)
-
cc_library(
name = "LinalgToLLVM",
srcs = [
"ConvertToLLVM.cpp",
- "FoldTensorExtractOpPass.cpp",
"KernelDispatch.cpp",
"LLVMCodeGenOptions.cpp",
"LinalgTileAndVectorizePass.cpp",
@@ -60,7 +40,6 @@
deps = [
"//iree/compiler/Conversion/CodegenUtils",
"//iree/compiler/Conversion/Common",
- "//iree/compiler/Conversion/LinalgToLLVM:FoldTensorExtractOpIncGen",
"//iree/compiler/Conversion/VectorToLLVM",
"//iree/compiler/Dialect/Flow/IR",
"//iree/compiler/Dialect/HAL/IR",
diff --git a/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt b/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt
index c42d4fc..2fdfd52 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt
+++ b/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt
@@ -10,15 +10,6 @@
iree_add_all_subdirs()
-iree_tablegen_library(
- NAME
- FoldTensorExtractOpIncGen
- TD_FILE
- "FoldTensorExtractOp.td"
- OUTS
- -gen-rewriters FoldTensorExtractOp.cpp.inc
-)
-
iree_cc_library(
NAME
LinalgToLLVM
@@ -28,7 +19,6 @@
"Passes.h"
SRCS
"ConvertToLLVM.cpp"
- "FoldTensorExtractOpPass.cpp"
"KernelDispatch.cpp"
"LLVMCodeGenOptions.cpp"
"LinalgTileAndVectorizePass.cpp"
@@ -66,7 +56,6 @@
MLIRVectorToSCF
iree::compiler::Conversion::CodegenUtils
iree::compiler::Conversion::Common
- iree::compiler::Conversion::LinalgToLLVM::FoldTensorExtractOpIncGen
iree::compiler::Conversion::VectorToLLVM
iree::compiler::Dialect::Flow::IR
iree::compiler::Dialect::HAL::IR
diff --git a/iree/compiler/Conversion/LinalgToLLVM/Passes.h b/iree/compiler/Conversion/LinalgToLLVM/Passes.h
index 147a650..3ce7581 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/Passes.h
+++ b/iree/compiler/Conversion/LinalgToLLVM/Passes.h
@@ -50,11 +50,6 @@
/// Pass to convert Linalg ops into vector operations.
std::unique_ptr<FunctionPass> createLinalgVectorizePass();
-/// After running the upstream TensorConstantBufferize pass, remove tensor_loads
-/// introduced for use only in tensor_extract. These can be folded to use a load
-/// of the created memref object that holds the constant values.
-std::unique_ptr<OperationPass<>> createFoldTensorExtractOpPass();
-
//===----------------------------------------------------------------------===//
// Pass Pipelines for CPU Lowering
//===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Conversion/LinalgToLLVM/test/BUILD b/iree/compiler/Conversion/LinalgToLLVM/test/BUILD
index 1f861d7..c0fc200 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/test/BUILD
+++ b/iree/compiler/Conversion/LinalgToLLVM/test/BUILD
@@ -30,7 +30,6 @@
"hal_interface_bindings.mlir",
"hal_interface_constants.mlir",
"hal_interface_workgroup_info.mlir",
- "fold_tensor_extract_op.mlir",
"linalg_vectorize.mlir",
"materialize_launch_configuration.mlir",
"matmul_vectorization.mlir",
diff --git a/iree/compiler/Conversion/LinalgToLLVM/test/CMakeLists.txt b/iree/compiler/Conversion/LinalgToLLVM/test/CMakeLists.txt
index 811ea5a..7b0821f 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/test/CMakeLists.txt
+++ b/iree/compiler/Conversion/LinalgToLLVM/test/CMakeLists.txt
@@ -14,7 +14,6 @@
NAME
lit
SRCS
- "fold_tensor_extract_op.mlir"
"hal_interface_bindings.mlir"
"hal_interface_constants.mlir"
"hal_interface_workgroup_info.mlir"
diff --git a/iree/compiler/Conversion/LinalgToLLVMGPU/BUILD b/iree/compiler/Conversion/LinalgToLLVMGPU/BUILD
index 58f9f9d..a101363 100644
--- a/iree/compiler/Conversion/LinalgToLLVMGPU/BUILD
+++ b/iree/compiler/Conversion/LinalgToLLVMGPU/BUILD
@@ -57,6 +57,7 @@
"@llvm-project//mlir:ROCDLDialect",
"@llvm-project//mlir:SCFToStandard",
"@llvm-project//mlir:StandardOps",
+ "@llvm-project//mlir:StandardOpsTransforms",
"@llvm-project//mlir:StandardToSPIRV",
"@llvm-project//mlir:Transforms",
"@llvm-project//mlir:VectorOps",
diff --git a/iree/compiler/Conversion/LinalgToLLVMGPU/CMakeLists.txt b/iree/compiler/Conversion/LinalgToLLVMGPU/CMakeLists.txt
index 687eb59..71a1a72 100644
--- a/iree/compiler/Conversion/LinalgToLLVMGPU/CMakeLists.txt
+++ b/iree/compiler/Conversion/LinalgToLLVMGPU/CMakeLists.txt
@@ -42,6 +42,7 @@
MLIRROCDLIR
MLIRSCFToStandard
MLIRStandard
+ MLIRStandardOpsTransforms
MLIRStandardToLLVM
MLIRStandardToSPIRV
MLIRTransforms
diff --git a/iree/compiler/Conversion/LinalgToLLVMGPU/Passes.cpp b/iree/compiler/Conversion/LinalgToLLVMGPU/Passes.cpp
index ccea7dd..80192a0 100644
--- a/iree/compiler/Conversion/LinalgToLLVMGPU/Passes.cpp
+++ b/iree/compiler/Conversion/LinalgToLLVMGPU/Passes.cpp
@@ -20,6 +20,7 @@
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
#include "mlir/Conversion/StandardToSPIRV/StandardToSPIRVPass.h"
#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassOptions.h"
#include "mlir/Pass/PassRegistry.h"
@@ -59,6 +60,10 @@
pm.nest<ModuleOp>().addNestedPass<FuncOp>(createCanonicalizerPass());
pm.nest<ModuleOp>().addNestedPass<FuncOp>(createCSEPass());
+  // Handle tensor-type constants: bufferize them into globals and fold the resulting tensor loads.
+ pm.addNestedPass<ModuleOp>(createTensorConstantBufferizePass());
+ pm.addNestedPass<ModuleOp>(createFoldTensorExtractOpPass());
+
// SCF -> STD
pm.nest<ModuleOp>().addNestedPass<FuncOp>(createLowerToCFGPass());
pm.nest<ModuleOp>().addNestedPass<FuncOp>(createCanonicalizerPass());
diff --git a/iree/compiler/Conversion/LinalgToLLVMGPU/test/nvvm_pipeline_test.mlir b/iree/compiler/Conversion/LinalgToLLVMGPU/test/nvvm_pipeline_test.mlir
index 3a94b8e..50659e0 100644
--- a/iree/compiler/Conversion/LinalgToLLVMGPU/test/nvvm_pipeline_test.mlir
+++ b/iree/compiler/Conversion/LinalgToLLVMGPU/test/nvvm_pipeline_test.mlir
@@ -181,3 +181,41 @@
// CHECK: llvm.fmul %{{.*}}, %{{.*}} : f32
// CHECK: llvm.fadd %{{.*}}, %{{.*}} : f32
// CHECK: llvm.store {{.*}} : !llvm.ptr<f32>
+
+// -----
+
+hal.executable @simpleMath_ex_dispatch_0 {
+ hal.interface @io {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", access="Write|Discard"
+ }
+ hal.executable.target @cuda, filter="cuda" {
+ hal.executable.entry_point @add_dispatch_0 attributes {interface = @io, ordinal = 0 : index}
+ module {
+ func @add_dispatch_0() {
+ %c0 = constant 0 : index
+ %0 = hal.interface.binding.subspan @io::@arg0[%c0] : !flow.dispatch.tensor<readonly:16xf32>
+ %2 = hal.interface.binding.subspan @io::@ret0[%c0] : !flow.dispatch.tensor<writeonly:16xf32>
+ %3 = linalg.init_tensor [16] : tensor<16xf32>
+ %4 = flow.dispatch.tensor.load %0, offsets=[], sizes=[], strides=[] : !flow.dispatch.tensor<readonly:16xf32> -> tensor<16xf32>
+ %5 = constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xf32>
+ %6 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%4, %5 : tensor<16xf32>, tensor<16xf32>) outs(%3 : tensor<16xf32>) {
+ ^bb0(%arg0: f32, %arg1: f32, %arg2: f32): // no predecessors
+ %7 = addf %arg0, %arg1 : f32
+ linalg.yield %7 : f32
+ } -> tensor<16xf32>
+ flow.dispatch.tensor.store %6, %2, offsets=[], sizes=[], strides=[] : tensor<16xf32> -> !flow.dispatch.tensor<writeonly:16xf32>
+ return
+ }
+ hal.interface @io attributes {sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", access="Write|Discard"
+ }
+ }
+ }
+}
+
+// CHECK-LABEL: hal.executable @simpleMath_ex_dispatch_0
+// CHECK: hal.executable.target @cuda, filter="cuda" {
+// CHECK: llvm.mlir.global private constant @{{.*}}(dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00, 7.000000e+00, 8.000000e+00, 9.000000e+00, 1.000000e+01, 1.100000e+01, 1.200000e+01, 1.300000e+01, 1.400000e+01, 1.500000e+01, 1.600000e+01]> : tensor<16xf32>)
+// CHECK: llvm.fadd